diff options
| -rw-r--r-- | docker-compose.yml | 5 | ||||
| -rw-r--r-- | services/api/src/trading_api/routers/orders.py | 90 | ||||
| -rw-r--r-- | services/api/src/trading_api/routers/portfolio.py | 64 | ||||
| -rw-r--r-- | services/api/src/trading_api/routers/strategies.py | 33 | ||||
| -rw-r--r-- | services/portfolio-manager/src/portfolio_manager/main.py | 9 | ||||
| -rw-r--r-- | shared/alembic.ini | 2 | ||||
| -rw-r--r-- | shared/tests/test_db.py | 1 |
7 files changed, 123 insertions, 81 deletions
diff --git a/docker-compose.yml b/docker-compose.yml index 1b72e8d..e981f74 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -115,6 +115,11 @@ services: condition: service_healthy postgres: condition: service_healthy + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"] + interval: 10s + timeout: 5s + retries: 3 restart: unless-stopped loki: diff --git a/services/api/src/trading_api/routers/orders.py b/services/api/src/trading_api/routers/orders.py index d0b9fa6..c69dc10 100644 --- a/services/api/src/trading_api/routers/orders.py +++ b/services/api/src/trading_api/routers/orders.py @@ -1,55 +1,67 @@ """Order endpoints.""" -from fastapi import APIRouter, Request +import logging + +from fastapi import APIRouter, HTTPException, Request from shared.sa_models import OrderRow, SignalRow from sqlalchemy import select +logger = logging.getLogger(__name__) + router = APIRouter() @router.get("/") async def get_orders(request: Request, limit: int = 50): """Get recent orders.""" - db = request.app.state.db - async with db.get_session() as session: - stmt = select(OrderRow).order_by(OrderRow.created_at.desc()).limit(limit) - result = await session.execute(stmt) - rows = result.scalars().all() - return [ - { - "id": r.id, - "signal_id": r.signal_id, - "symbol": r.symbol, - "side": r.side, - "type": r.type, - "price": float(r.price), - "quantity": float(r.quantity), - "status": r.status, - "created_at": r.created_at.isoformat() if r.created_at else None, - "filled_at": r.filled_at.isoformat() if r.filled_at else None, - } - for r in rows - ] + try: + db = request.app.state.db + async with db.get_session() as session: + stmt = select(OrderRow).order_by(OrderRow.created_at.desc()).limit(limit) + result = await session.execute(stmt) + rows = result.scalars().all() + return [ + { + "id": r.id, + "signal_id": r.signal_id, + "symbol": r.symbol, + "side": r.side, + "type": r.type, + "price": float(r.price), + "quantity": float(r.quantity), + "status": r.status, + "created_at": r.created_at.isoformat() if r.created_at else None, + "filled_at": r.filled_at.isoformat() if r.filled_at else None, + } + for r in rows + ] + except Exception as exc: + logger.error("Failed to get orders: %s", exc) + raise HTTPException(status_code=500, detail="Failed to retrieve orders") @router.get("/signals") async def get_signals(request: Request, limit: int = 50): """Get recent signals.""" - db = request.app.state.db - async with db.get_session() as session: - stmt = select(SignalRow).order_by(SignalRow.created_at.desc()).limit(limit) - result = await session.execute(stmt) - rows = result.scalars().all() - return [ - { - "id": r.id, - "strategy": r.strategy, - "symbol": r.symbol, - "side": r.side, - "price": float(r.price), - "quantity": float(r.quantity), - "reason": r.reason, - "created_at": r.created_at.isoformat() if r.created_at else None, - } - for r in rows - ] + try: + db = request.app.state.db + async with db.get_session() as session: + stmt = select(SignalRow).order_by(SignalRow.created_at.desc()).limit(limit) + result = await session.execute(stmt) + rows = result.scalars().all() + return [ + { + "id": r.id, + "strategy": r.strategy, + "symbol": r.symbol, + "side": r.side, + "price": float(r.price), + "quantity": float(r.quantity), + "reason": r.reason, + "created_at": r.created_at.isoformat() if r.created_at else None, + } + for r in rows + ] + except Exception as exc: + logger.error("Failed to get signals: %s", exc) + raise HTTPException(status_code=500, detail="Failed to retrieve signals") diff --git a/services/api/src/trading_api/routers/portfolio.py b/services/api/src/trading_api/routers/portfolio.py index 3b30e1d..d76d85d 100644 --- a/services/api/src/trading_api/routers/portfolio.py +++ b/services/api/src/trading_api/routers/portfolio.py @@ -1,42 +1,54 @@ """Portfolio endpoints.""" -from fastapi import APIRouter, Request +import logging + +from fastapi import APIRouter, HTTPException, Request from shared.sa_models import PositionRow from sqlalchemy import select +logger = logging.getLogger(__name__) + router = APIRouter() @router.get("/positions") async def get_positions(request: Request): """Get all current positions.""" - db = request.app.state.db - async with db.get_session() as session: - result = await session.execute(select(PositionRow)) - rows = result.scalars().all() - return [ - { - "symbol": r.symbol, - "quantity": float(r.quantity), - "avg_entry_price": float(r.avg_entry_price), - "current_price": float(r.current_price), - "unrealized_pnl": float(r.quantity * (r.current_price - r.avg_entry_price)), - } - for r in rows - ] + try: + db = request.app.state.db + async with db.get_session() as session: + result = await session.execute(select(PositionRow)) + rows = result.scalars().all() + return [ + { + "symbol": r.symbol, + "quantity": float(r.quantity), + "avg_entry_price": float(r.avg_entry_price), + "current_price": float(r.current_price), + "unrealized_pnl": float(r.quantity * (r.current_price - r.avg_entry_price)), + } + for r in rows + ] + except Exception as exc: + logger.error("Failed to get positions: %s", exc) + raise HTTPException(status_code=500, detail="Failed to retrieve positions") @router.get("/snapshots") async def get_snapshots(request: Request, days: int = 30): """Get portfolio snapshots for the last N days.""" - db = request.app.state.db - snapshots = await db.get_portfolio_snapshots(days=days) - return [ - { - "total_value": float(s["total_value"]), - "realized_pnl": float(s["realized_pnl"]), - "unrealized_pnl": float(s["unrealized_pnl"]), - "snapshot_at": s["snapshot_at"].isoformat(), - } - for s in snapshots - ] + try: + db = request.app.state.db + snapshots = await db.get_portfolio_snapshots(days=days) + return [ + { + "total_value": float(s["total_value"]), + "realized_pnl": float(s["realized_pnl"]), + "unrealized_pnl": float(s["unrealized_pnl"]), + "snapshot_at": s["snapshot_at"].isoformat(), + } + for s in snapshots + ] + except Exception as exc: + logger.error("Failed to get snapshots: %s", exc) + raise HTTPException(status_code=500, detail="Failed to retrieve snapshots") diff --git a/services/api/src/trading_api/routers/strategies.py b/services/api/src/trading_api/routers/strategies.py index 2861eec..e968529 100644 --- a/services/api/src/trading_api/routers/strategies.py +++ b/services/api/src/trading_api/routers/strategies.py @@ -1,30 +1,37 @@ """Strategy endpoints.""" +import logging import sys from pathlib import Path -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException # Add strategy-engine to path for plugin loading _STRATEGY_DIR = Path(__file__).resolve().parents[5] / "strategy-engine" if str(_STRATEGY_DIR) not in sys.path: sys.path.insert(0, str(_STRATEGY_DIR)) +logger = logging.getLogger(__name__) + router = APIRouter() @router.get("/") async def list_strategies(): """List available strategies.""" - from strategy_engine.plugin_loader import load_strategies - - strategies_dir = _STRATEGY_DIR / "strategies" - strategies = load_strategies(strategies_dir) - return [ - { - "name": s.name, - "warmup_period": s.warmup_period, - "class": type(s).__name__, - } - for s in strategies - ] + try: + from strategy_engine.plugin_loader import load_strategies + + strategies_dir = _STRATEGY_DIR / "strategies" + strategies = load_strategies(strategies_dir) + return [ + { + "name": s.name, + "warmup_period": s.warmup_period, + "class": type(s).__name__, + } + for s in strategies + ] + except Exception as exc: + logger.error("Failed to list strategies: %s", exc) + raise HTTPException(status_code=500, detail="Failed to list strategies") diff --git a/services/portfolio-manager/src/portfolio_manager/main.py b/services/portfolio-manager/src/portfolio_manager/main.py index d60e6c9..ce174e8 100644 --- a/services/portfolio-manager/src/portfolio_manager/main.py +++ b/services/portfolio-manager/src/portfolio_manager/main.py @@ -16,6 +16,11 @@ from portfolio_manager.portfolio import PortfolioTracker ORDERS_STREAM = "orders" +# Health check port: base (HEALTH_PORT, default 8080) + offset +# data-collector: +0 (8080), strategy-engine: +1 (8081) +# order-executor: +2 (8082), portfolio-manager: +3 (8083) +HEALTH_PORT_OFFSET = 3 + async def save_snapshot( db: Database, @@ -45,11 +50,11 @@ async def snapshot_loop( ) -> None: """Periodically save portfolio snapshots and send daily summary.""" while True: - await asyncio.sleep(interval_hours * 3600) try: await save_snapshot(db, tracker, notifier, log) except Exception as exc: log.error("snapshot_failed", error=str(exc)) + await asyncio.sleep(interval_hours * 3600) async def run() -> None: @@ -64,7 +69,7 @@ async def run() -> None: tracker = PortfolioTracker() health = HealthCheckServer( - "portfolio-manager", port=config.health_port + 3, auth_token=config.metrics_auth_token + "portfolio-manager", port=config.health_port + HEALTH_PORT_OFFSET, auth_token=config.metrics_auth_token ) health.register_check("redis", broker.ping) await health.start() diff --git a/shared/alembic.ini b/shared/alembic.ini index 2c4fd1f..7206a9b 100644 --- a/shared/alembic.ini +++ b/shared/alembic.ini @@ -1,6 +1,6 @@ [alembic] script_location = alembic -sqlalchemy.url = postgresql+asyncpg://postgres:postgres@localhost:5432/trading +sqlalchemy.url = postgresql+asyncpg://trading:trading@localhost:5432/trading [loggers] keys = root,sqlalchemy,alembic diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index f4cabfd..d33dfe1 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -306,6 +306,7 @@ class TestTransactionContextManager: db = Database("postgresql+asyncpg://host/db") mock_session = AsyncMock() + mock_session.add = MagicMock() mock_session.__aenter__ = AsyncMock(return_value=mock_session) mock_session.__aexit__ = AsyncMock(return_value=False) |
