summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-01 17:46:47 +0900
committerTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-01 17:46:47 +0900
commit69e88b3b353f1a2ab7a78259b480e8afbd87669c (patch)
tree712de7c4dd0bd16ba853a77d012ebed7c57d91c7
parent678005dc51892c4c1f4cea2730bbf0ec4ebc312d (diff)
fix: snapshot delay, env fields, alembic creds, API healthcheck and error handling
-rw-r--r--docker-compose.yml5
-rw-r--r--services/api/src/trading_api/routers/orders.py90
-rw-r--r--services/api/src/trading_api/routers/portfolio.py64
-rw-r--r--services/api/src/trading_api/routers/strategies.py33
-rw-r--r--services/portfolio-manager/src/portfolio_manager/main.py9
-rw-r--r--shared/alembic.ini2
-rw-r--r--shared/tests/test_db.py1
7 files changed, 123 insertions, 81 deletions
diff --git a/docker-compose.yml b/docker-compose.yml
index 1b72e8d..e981f74 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -115,6 +115,11 @@ services:
condition: service_healthy
postgres:
condition: service_healthy
+ healthcheck:
+ test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
+ interval: 10s
+ timeout: 5s
+ retries: 3
restart: unless-stopped
loki:
diff --git a/services/api/src/trading_api/routers/orders.py b/services/api/src/trading_api/routers/orders.py
index d0b9fa6..c69dc10 100644
--- a/services/api/src/trading_api/routers/orders.py
+++ b/services/api/src/trading_api/routers/orders.py
@@ -1,55 +1,67 @@
"""Order endpoints."""
-from fastapi import APIRouter, Request
+import logging
+
+from fastapi import APIRouter, HTTPException, Request
from shared.sa_models import OrderRow, SignalRow
from sqlalchemy import select
+logger = logging.getLogger(__name__)
+
router = APIRouter()
@router.get("/")
async def get_orders(request: Request, limit: int = 50):
"""Get recent orders."""
- db = request.app.state.db
- async with db.get_session() as session:
- stmt = select(OrderRow).order_by(OrderRow.created_at.desc()).limit(limit)
- result = await session.execute(stmt)
- rows = result.scalars().all()
- return [
- {
- "id": r.id,
- "signal_id": r.signal_id,
- "symbol": r.symbol,
- "side": r.side,
- "type": r.type,
- "price": float(r.price),
- "quantity": float(r.quantity),
- "status": r.status,
- "created_at": r.created_at.isoformat() if r.created_at else None,
- "filled_at": r.filled_at.isoformat() if r.filled_at else None,
- }
- for r in rows
- ]
+ try:
+ db = request.app.state.db
+ async with db.get_session() as session:
+ stmt = select(OrderRow).order_by(OrderRow.created_at.desc()).limit(limit)
+ result = await session.execute(stmt)
+ rows = result.scalars().all()
+ return [
+ {
+ "id": r.id,
+ "signal_id": r.signal_id,
+ "symbol": r.symbol,
+ "side": r.side,
+ "type": r.type,
+ "price": float(r.price),
+ "quantity": float(r.quantity),
+ "status": r.status,
+ "created_at": r.created_at.isoformat() if r.created_at else None,
+ "filled_at": r.filled_at.isoformat() if r.filled_at else None,
+ }
+ for r in rows
+ ]
+ except Exception as exc:
+ logger.error("Failed to get orders: %s", exc)
+ raise HTTPException(status_code=500, detail="Failed to retrieve orders")
@router.get("/signals")
async def get_signals(request: Request, limit: int = 50):
"""Get recent signals."""
- db = request.app.state.db
- async with db.get_session() as session:
- stmt = select(SignalRow).order_by(SignalRow.created_at.desc()).limit(limit)
- result = await session.execute(stmt)
- rows = result.scalars().all()
- return [
- {
- "id": r.id,
- "strategy": r.strategy,
- "symbol": r.symbol,
- "side": r.side,
- "price": float(r.price),
- "quantity": float(r.quantity),
- "reason": r.reason,
- "created_at": r.created_at.isoformat() if r.created_at else None,
- }
- for r in rows
- ]
+ try:
+ db = request.app.state.db
+ async with db.get_session() as session:
+ stmt = select(SignalRow).order_by(SignalRow.created_at.desc()).limit(limit)
+ result = await session.execute(stmt)
+ rows = result.scalars().all()
+ return [
+ {
+ "id": r.id,
+ "strategy": r.strategy,
+ "symbol": r.symbol,
+ "side": r.side,
+ "price": float(r.price),
+ "quantity": float(r.quantity),
+ "reason": r.reason,
+ "created_at": r.created_at.isoformat() if r.created_at else None,
+ }
+ for r in rows
+ ]
+ except Exception as exc:
+ logger.error("Failed to get signals: %s", exc)
+ raise HTTPException(status_code=500, detail="Failed to retrieve signals")
diff --git a/services/api/src/trading_api/routers/portfolio.py b/services/api/src/trading_api/routers/portfolio.py
index 3b30e1d..d76d85d 100644
--- a/services/api/src/trading_api/routers/portfolio.py
+++ b/services/api/src/trading_api/routers/portfolio.py
@@ -1,42 +1,54 @@
"""Portfolio endpoints."""
-from fastapi import APIRouter, Request
+import logging
+
+from fastapi import APIRouter, HTTPException, Request
from shared.sa_models import PositionRow
from sqlalchemy import select
+logger = logging.getLogger(__name__)
+
router = APIRouter()
@router.get("/positions")
async def get_positions(request: Request):
"""Get all current positions."""
- db = request.app.state.db
- async with db.get_session() as session:
- result = await session.execute(select(PositionRow))
- rows = result.scalars().all()
- return [
- {
- "symbol": r.symbol,
- "quantity": float(r.quantity),
- "avg_entry_price": float(r.avg_entry_price),
- "current_price": float(r.current_price),
- "unrealized_pnl": float(r.quantity * (r.current_price - r.avg_entry_price)),
- }
- for r in rows
- ]
+ try:
+ db = request.app.state.db
+ async with db.get_session() as session:
+ result = await session.execute(select(PositionRow))
+ rows = result.scalars().all()
+ return [
+ {
+ "symbol": r.symbol,
+ "quantity": float(r.quantity),
+ "avg_entry_price": float(r.avg_entry_price),
+ "current_price": float(r.current_price),
+ "unrealized_pnl": float(r.quantity * (r.current_price - r.avg_entry_price)),
+ }
+ for r in rows
+ ]
+ except Exception as exc:
+ logger.error("Failed to get positions: %s", exc)
+ raise HTTPException(status_code=500, detail="Failed to retrieve positions")
@router.get("/snapshots")
async def get_snapshots(request: Request, days: int = 30):
"""Get portfolio snapshots for the last N days."""
- db = request.app.state.db
- snapshots = await db.get_portfolio_snapshots(days=days)
- return [
- {
- "total_value": float(s["total_value"]),
- "realized_pnl": float(s["realized_pnl"]),
- "unrealized_pnl": float(s["unrealized_pnl"]),
- "snapshot_at": s["snapshot_at"].isoformat(),
- }
- for s in snapshots
- ]
+ try:
+ db = request.app.state.db
+ snapshots = await db.get_portfolio_snapshots(days=days)
+ return [
+ {
+ "total_value": float(s["total_value"]),
+ "realized_pnl": float(s["realized_pnl"]),
+ "unrealized_pnl": float(s["unrealized_pnl"]),
+ "snapshot_at": s["snapshot_at"].isoformat(),
+ }
+ for s in snapshots
+ ]
+ except Exception as exc:
+ logger.error("Failed to get snapshots: %s", exc)
+ raise HTTPException(status_code=500, detail="Failed to retrieve snapshots")
diff --git a/services/api/src/trading_api/routers/strategies.py b/services/api/src/trading_api/routers/strategies.py
index 2861eec..e968529 100644
--- a/services/api/src/trading_api/routers/strategies.py
+++ b/services/api/src/trading_api/routers/strategies.py
@@ -1,30 +1,37 @@
"""Strategy endpoints."""
+import logging
import sys
from pathlib import Path
-from fastapi import APIRouter
+from fastapi import APIRouter, HTTPException
# Add strategy-engine to path for plugin loading
_STRATEGY_DIR = Path(__file__).resolve().parents[5] / "strategy-engine"
if str(_STRATEGY_DIR) not in sys.path:
sys.path.insert(0, str(_STRATEGY_DIR))
+logger = logging.getLogger(__name__)
+
router = APIRouter()
@router.get("/")
async def list_strategies():
"""List available strategies."""
- from strategy_engine.plugin_loader import load_strategies
-
- strategies_dir = _STRATEGY_DIR / "strategies"
- strategies = load_strategies(strategies_dir)
- return [
- {
- "name": s.name,
- "warmup_period": s.warmup_period,
- "class": type(s).__name__,
- }
- for s in strategies
- ]
+ try:
+ from strategy_engine.plugin_loader import load_strategies
+
+ strategies_dir = _STRATEGY_DIR / "strategies"
+ strategies = load_strategies(strategies_dir)
+ return [
+ {
+ "name": s.name,
+ "warmup_period": s.warmup_period,
+ "class": type(s).__name__,
+ }
+ for s in strategies
+ ]
+ except Exception as exc:
+ logger.error("Failed to list strategies: %s", exc)
+ raise HTTPException(status_code=500, detail="Failed to list strategies")
diff --git a/services/portfolio-manager/src/portfolio_manager/main.py b/services/portfolio-manager/src/portfolio_manager/main.py
index d60e6c9..ce174e8 100644
--- a/services/portfolio-manager/src/portfolio_manager/main.py
+++ b/services/portfolio-manager/src/portfolio_manager/main.py
@@ -16,6 +16,11 @@ from portfolio_manager.portfolio import PortfolioTracker
ORDERS_STREAM = "orders"
+# Health check port: base (HEALTH_PORT, default 8080) + offset
+# data-collector: +0 (8080), strategy-engine: +1 (8081)
+# order-executor: +2 (8082), portfolio-manager: +3 (8083)
+HEALTH_PORT_OFFSET = 3
+
async def save_snapshot(
db: Database,
@@ -45,11 +50,11 @@ async def snapshot_loop(
) -> None:
"""Periodically save portfolio snapshots and send daily summary."""
while True:
- await asyncio.sleep(interval_hours * 3600)
try:
await save_snapshot(db, tracker, notifier, log)
except Exception as exc:
log.error("snapshot_failed", error=str(exc))
+ await asyncio.sleep(interval_hours * 3600)
async def run() -> None:
@@ -64,7 +69,7 @@ async def run() -> None:
tracker = PortfolioTracker()
health = HealthCheckServer(
- "portfolio-manager", port=config.health_port + 3, auth_token=config.metrics_auth_token
+ "portfolio-manager", port=config.health_port + HEALTH_PORT_OFFSET, auth_token=config.metrics_auth_token
)
health.register_check("redis", broker.ping)
await health.start()
diff --git a/shared/alembic.ini b/shared/alembic.ini
index 2c4fd1f..7206a9b 100644
--- a/shared/alembic.ini
+++ b/shared/alembic.ini
@@ -1,6 +1,6 @@
[alembic]
script_location = alembic
-sqlalchemy.url = postgresql+asyncpg://postgres:postgres@localhost:5432/trading
+sqlalchemy.url = postgresql+asyncpg://trading:trading@localhost:5432/trading
[loggers]
keys = root,sqlalchemy,alembic
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
index f4cabfd..d33dfe1 100644
--- a/shared/tests/test_db.py
+++ b/shared/tests/test_db.py
@@ -306,6 +306,7 @@ class TestTransactionContextManager:
db = Database("postgresql+asyncpg://host/db")
mock_session = AsyncMock()
+ mock_session.add = MagicMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)