diff options
| -rw-r--r-- | services/portfolio-manager/src/portfolio_manager/main.py | 46 | ||||
| -rw-r--r-- | services/portfolio-manager/tests/test_snapshot.py | 67 | ||||
| -rw-r--r-- | shared/src/shared/db.py | 47 | ||||
| -rw-r--r-- | shared/tests/test_db.py | 79 |
4 files changed, 237 insertions, 2 deletions
diff --git a/services/portfolio-manager/src/portfolio_manager/main.py b/services/portfolio-manager/src/portfolio_manager/main.py index a1c73be..02df5d2 100644 --- a/services/portfolio-manager/src/portfolio_manager/main.py +++ b/services/portfolio-manager/src/portfolio_manager/main.py @@ -1,8 +1,10 @@ """Portfolio Manager Service entry point.""" import asyncio +from decimal import Decimal from shared.broker import RedisBroker +from shared.db import Database from shared.events import Event, OrderEvent from shared.healthcheck import HealthCheckServer from shared.logging import setup_logging @@ -15,6 +17,41 @@ from portfolio_manager.portfolio import PortfolioTracker ORDERS_STREAM = "orders" +async def save_snapshot( + db: Database, + tracker: PortfolioTracker, + notifier: TelegramNotifier, + log, +) -> None: + """Compute and save a portfolio snapshot, then send a daily Telegram summary.""" + positions = tracker.get_all_positions() + total_value = sum(p.quantity * p.current_price for p in positions) + unrealized = sum(p.unrealized_pnl for p in positions) + await db.insert_portfolio_snapshot( + total_value=total_value, + realized_pnl=Decimal("0"), # TODO: track realized PnL + unrealized_pnl=unrealized, + ) + await notifier.send_daily_summary(positions, total_value, unrealized) + log.info("snapshot_saved", total_value=str(total_value), positions=len(positions)) + + +async def snapshot_loop( + db: Database, + tracker: PortfolioTracker, + notifier: TelegramNotifier, + interval_hours: int, + log, +) -> None: + """Periodically save portfolio snapshots and send daily summary.""" + while True: + await asyncio.sleep(interval_hours * 3600) + try: + await save_snapshot(db, tracker, notifier, log) + except Exception as exc: + log.error("snapshot_failed", error=str(exc)) + + async def run() -> None: config = PortfolioConfig() log = setup_logging("portfolio-manager", config.log_level, config.log_format) @@ -31,6 +68,13 @@ async def run() -> None: await health.start() metrics.service_up.labels(service="portfolio-manager").set(1) + db = Database(config.database_url) + await db.connect() + + snapshot_task = asyncio.create_task( + snapshot_loop(db, tracker, notifier, config.snapshot_interval_hours, log) + ) + last_id = "$" log.info("service_started", stream=ORDERS_STREAM) @@ -67,9 +111,11 @@ async def run() -> None: await notifier.send_error(str(exc), "portfolio-manager") raise finally: + snapshot_task.cancel() metrics.service_up.labels(service="portfolio-manager").set(0) await notifier.close() await broker.close() + await db.close() def main() -> None: diff --git a/services/portfolio-manager/tests/test_snapshot.py b/services/portfolio-manager/tests/test_snapshot.py new file mode 100644 index 0000000..89d23d7 --- /dev/null +++ b/services/portfolio-manager/tests/test_snapshot.py @@ -0,0 +1,67 @@ +"""Tests for save_snapshot in portfolio-manager.""" + +import pytest +from decimal import Decimal +from unittest.mock import AsyncMock, MagicMock + +from shared.models import Position + + +class TestSaveSnapshot: + @pytest.mark.asyncio + async def test_save_snapshot_saves_to_db_and_notifies(self): + from portfolio_manager.main import save_snapshot + + pos = Position( + symbol="BTCUSDT", + quantity=Decimal("0.5"), + avg_entry_price=Decimal("50000"), + current_price=Decimal("52000"), + ) + + tracker = MagicMock() + tracker.get_all_positions.return_value = [pos] + + db = AsyncMock() + notifier = AsyncMock() + log = MagicMock() + + await save_snapshot(db, tracker, notifier, log) + + expected_total = Decimal("0.5") * Decimal("52000") # 26000 + expected_unrealized = Decimal("0.5") * (Decimal("52000") - Decimal("50000")) # 1000 + + db.insert_portfolio_snapshot.assert_awaited_once_with( + total_value=expected_total, + realized_pnl=Decimal("0"), + unrealized_pnl=expected_unrealized, + ) + notifier.send_daily_summary.assert_awaited_once_with( + [pos], expected_total, expected_unrealized + ) + log.info.assert_called_once_with( + "snapshot_saved", + total_value=str(expected_total), + positions=1, + ) + + @pytest.mark.asyncio + async def test_save_snapshot_empty_positions(self): + from portfolio_manager.main import save_snapshot + + tracker = MagicMock() + tracker.get_all_positions.return_value = [] + + db = AsyncMock() + notifier = AsyncMock() + log = MagicMock() + + await save_snapshot(db, tracker, notifier, log) + + db.insert_portfolio_snapshot.assert_awaited_once_with( + total_value=Decimal("0"), + realized_pnl=Decimal("0"), + unrealized_pnl=Decimal("0"), + ) + notifier.send_daily_summary.assert_awaited_once() + log.info.assert_called_once() diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py index 515ba2c..901e293 100644 --- a/shared/src/shared/db.py +++ b/shared/src/shared/db.py @@ -1,14 +1,15 @@ """Database layer using SQLAlchemy 2.0 async ORM for the trading platform.""" from contextlib import asynccontextmanager -from datetime import datetime +from datetime import datetime, timedelta, timezone +from decimal import Decimal from typing import Optional from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from shared.models import Candle, Signal, Order, OrderStatus -from shared.sa_models import Base, CandleRow, SignalRow, OrderRow +from shared.sa_models import Base, CandleRow, SignalRow, OrderRow, PortfolioSnapshotRow class Database: @@ -152,3 +153,45 @@ class Database: await session.rollback() raise return [dict(row._mapping) for row in rows] + + async def insert_portfolio_snapshot( + self, + total_value: Decimal, + realized_pnl: Decimal, + unrealized_pnl: Decimal, + ) -> None: + """Insert a portfolio snapshot.""" + async with self.get_session() as session: + try: + row = PortfolioSnapshotRow( + total_value=total_value, + realized_pnl=realized_pnl, + unrealized_pnl=unrealized_pnl, + snapshot_at=datetime.now(timezone.utc), + ) + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise + + async def get_portfolio_snapshots(self, days: int = 30) -> list[dict]: + """Retrieve recent portfolio snapshots.""" + async with self.get_session() as session: + since = datetime.now(timezone.utc) - timedelta(days=days) + stmt = ( + select(PortfolioSnapshotRow) + .where(PortfolioSnapshotRow.snapshot_at >= since) + .order_by(PortfolioSnapshotRow.snapshot_at.desc()) + ) + result = await session.execute(stmt) + rows = result.scalars().all() + return [ + { + "total_value": r.total_value, + "realized_pnl": r.realized_pnl, + "unrealized_pnl": r.unrealized_pnl, + "snapshot_at": r.snapshot_at, + } + for r in rows + ] diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index 04efe9e..f4cabfd 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -335,3 +335,82 @@ class TestTransactionContextManager: mock_session.rollback.assert_awaited_once() mock_session.commit.assert_not_awaited() + + +class TestInsertPortfolioSnapshot: + @pytest.mark.asyncio + async def test_insert_portfolio_snapshot_uses_add_and_commit(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + await db.insert_portfolio_snapshot( + total_value=Decimal("10000"), + realized_pnl=Decimal("0"), + unrealized_pnl=Decimal("500"), + ) + + mock_session.add.assert_called_once() + mock_session.commit.assert_awaited_once() + + @pytest.mark.asyncio + async def test_insert_portfolio_snapshot_rollback_on_error(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock(side_effect=RuntimeError("db error")) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + with pytest.raises(RuntimeError, match="db error"): + await db.insert_portfolio_snapshot( + total_value=Decimal("10000"), + realized_pnl=Decimal("0"), + unrealized_pnl=Decimal("500"), + ) + + mock_session.rollback.assert_awaited_once() + + +class TestGetPortfolioSnapshots: + @pytest.mark.asyncio + async def test_get_portfolio_snapshots_returns_list_of_dicts(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_row = MagicMock() + mock_row.total_value = Decimal("10000") + mock_row.realized_pnl = Decimal("0") + mock_row.unrealized_pnl = Decimal("500") + mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=timezone.utc) + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_row] + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + result = await db.get_portfolio_snapshots(days=30) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["total_value"] == Decimal("10000") + assert result[0]["unrealized_pnl"] == Decimal("500") + mock_session.execute.assert_awaited_once() |
