summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-01 17:15:23 +0900
committerTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-01 17:15:23 +0900
commit8c11cae987292421840658585c0667706790c8ca (patch)
tree5ab5571bf1bdc36e3ae3a09c0c29f49ca74c1135
parenta6bf0057d32df7ed0a1d6ec6d19daf74a0de5c0f (diff)
feat(portfolio): add periodic portfolio snapshots and daily Telegram summary
-rw-r--r--services/portfolio-manager/src/portfolio_manager/main.py46
-rw-r--r--services/portfolio-manager/tests/test_snapshot.py67
-rw-r--r--shared/src/shared/db.py47
-rw-r--r--shared/tests/test_db.py79
4 files changed, 237 insertions, 2 deletions
diff --git a/services/portfolio-manager/src/portfolio_manager/main.py b/services/portfolio-manager/src/portfolio_manager/main.py
index a1c73be..02df5d2 100644
--- a/services/portfolio-manager/src/portfolio_manager/main.py
+++ b/services/portfolio-manager/src/portfolio_manager/main.py
@@ -1,8 +1,10 @@
"""Portfolio Manager Service entry point."""
import asyncio
+from decimal import Decimal
from shared.broker import RedisBroker
+from shared.db import Database
from shared.events import Event, OrderEvent
from shared.healthcheck import HealthCheckServer
from shared.logging import setup_logging
@@ -15,6 +17,41 @@ from portfolio_manager.portfolio import PortfolioTracker
ORDERS_STREAM = "orders"
+async def save_snapshot(
+ db: Database,
+ tracker: PortfolioTracker,
+ notifier: TelegramNotifier,
+ log,
+) -> None:
+ """Compute and save a portfolio snapshot, then send a daily Telegram summary."""
+ positions = tracker.get_all_positions()
+ total_value = sum(p.quantity * p.current_price for p in positions)
+ unrealized = sum(p.unrealized_pnl for p in positions)
+ await db.insert_portfolio_snapshot(
+ total_value=total_value,
+ realized_pnl=Decimal("0"), # TODO: track realized PnL
+ unrealized_pnl=unrealized,
+ )
+ await notifier.send_daily_summary(positions, total_value, unrealized)
+ log.info("snapshot_saved", total_value=str(total_value), positions=len(positions))
+
+
+async def snapshot_loop(
+ db: Database,
+ tracker: PortfolioTracker,
+ notifier: TelegramNotifier,
+ interval_hours: int,
+ log,
+) -> None:
+ """Periodically save portfolio snapshots and send daily summary."""
+ while True:
+ await asyncio.sleep(interval_hours * 3600)
+ try:
+ await save_snapshot(db, tracker, notifier, log)
+ except Exception as exc:
+ log.error("snapshot_failed", error=str(exc))
+
+
async def run() -> None:
config = PortfolioConfig()
log = setup_logging("portfolio-manager", config.log_level, config.log_format)
@@ -31,6 +68,13 @@ async def run() -> None:
await health.start()
metrics.service_up.labels(service="portfolio-manager").set(1)
+ db = Database(config.database_url)
+ await db.connect()
+
+ snapshot_task = asyncio.create_task(
+ snapshot_loop(db, tracker, notifier, config.snapshot_interval_hours, log)
+ )
+
last_id = "$"
log.info("service_started", stream=ORDERS_STREAM)
@@ -67,9 +111,11 @@ async def run() -> None:
await notifier.send_error(str(exc), "portfolio-manager")
raise
finally:
+ snapshot_task.cancel()
metrics.service_up.labels(service="portfolio-manager").set(0)
await notifier.close()
await broker.close()
+ await db.close()
def main() -> None:
diff --git a/services/portfolio-manager/tests/test_snapshot.py b/services/portfolio-manager/tests/test_snapshot.py
new file mode 100644
index 0000000..89d23d7
--- /dev/null
+++ b/services/portfolio-manager/tests/test_snapshot.py
@@ -0,0 +1,67 @@
+"""Tests for save_snapshot in portfolio-manager."""
+
+import pytest
+from decimal import Decimal
+from unittest.mock import AsyncMock, MagicMock
+
+from shared.models import Position
+
+
+class TestSaveSnapshot:
+ @pytest.mark.asyncio
+ async def test_save_snapshot_saves_to_db_and_notifies(self):
+ from portfolio_manager.main import save_snapshot
+
+ pos = Position(
+ symbol="BTCUSDT",
+ quantity=Decimal("0.5"),
+ avg_entry_price=Decimal("50000"),
+ current_price=Decimal("52000"),
+ )
+
+ tracker = MagicMock()
+ tracker.get_all_positions.return_value = [pos]
+
+ db = AsyncMock()
+ notifier = AsyncMock()
+ log = MagicMock()
+
+ await save_snapshot(db, tracker, notifier, log)
+
+ expected_total = Decimal("0.5") * Decimal("52000") # 26000
+ expected_unrealized = Decimal("0.5") * (Decimal("52000") - Decimal("50000")) # 1000
+
+ db.insert_portfolio_snapshot.assert_awaited_once_with(
+ total_value=expected_total,
+ realized_pnl=Decimal("0"),
+ unrealized_pnl=expected_unrealized,
+ )
+ notifier.send_daily_summary.assert_awaited_once_with(
+ [pos], expected_total, expected_unrealized
+ )
+ log.info.assert_called_once_with(
+ "snapshot_saved",
+ total_value=str(expected_total),
+ positions=1,
+ )
+
+ @pytest.mark.asyncio
+ async def test_save_snapshot_empty_positions(self):
+ from portfolio_manager.main import save_snapshot
+
+ tracker = MagicMock()
+ tracker.get_all_positions.return_value = []
+
+ db = AsyncMock()
+ notifier = AsyncMock()
+ log = MagicMock()
+
+ await save_snapshot(db, tracker, notifier, log)
+
+ db.insert_portfolio_snapshot.assert_awaited_once_with(
+ total_value=Decimal("0"),
+ realized_pnl=Decimal("0"),
+ unrealized_pnl=Decimal("0"),
+ )
+ notifier.send_daily_summary.assert_awaited_once()
+ log.info.assert_called_once()
diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py
index 515ba2c..901e293 100644
--- a/shared/src/shared/db.py
+++ b/shared/src/shared/db.py
@@ -1,14 +1,15 @@
"""Database layer using SQLAlchemy 2.0 async ORM for the trading platform."""
from contextlib import asynccontextmanager
-from datetime import datetime
+from datetime import datetime, timedelta, timezone
+from decimal import Decimal
from typing import Optional
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from shared.models import Candle, Signal, Order, OrderStatus
-from shared.sa_models import Base, CandleRow, SignalRow, OrderRow
+from shared.sa_models import Base, CandleRow, SignalRow, OrderRow, PortfolioSnapshotRow
class Database:
@@ -152,3 +153,45 @@ class Database:
await session.rollback()
raise
return [dict(row._mapping) for row in rows]
+
+ async def insert_portfolio_snapshot(
+ self,
+ total_value: Decimal,
+ realized_pnl: Decimal,
+ unrealized_pnl: Decimal,
+ ) -> None:
+ """Insert a portfolio snapshot."""
+ async with self.get_session() as session:
+ try:
+ row = PortfolioSnapshotRow(
+ total_value=total_value,
+ realized_pnl=realized_pnl,
+ unrealized_pnl=unrealized_pnl,
+ snapshot_at=datetime.now(timezone.utc),
+ )
+ session.add(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
+
+ async def get_portfolio_snapshots(self, days: int = 30) -> list[dict]:
+ """Retrieve recent portfolio snapshots."""
+ async with self.get_session() as session:
+ since = datetime.now(timezone.utc) - timedelta(days=days)
+ stmt = (
+ select(PortfolioSnapshotRow)
+ .where(PortfolioSnapshotRow.snapshot_at >= since)
+ .order_by(PortfolioSnapshotRow.snapshot_at.desc())
+ )
+ result = await session.execute(stmt)
+ rows = result.scalars().all()
+ return [
+ {
+ "total_value": r.total_value,
+ "realized_pnl": r.realized_pnl,
+ "unrealized_pnl": r.unrealized_pnl,
+ "snapshot_at": r.snapshot_at,
+ }
+ for r in rows
+ ]
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
index 04efe9e..f4cabfd 100644
--- a/shared/tests/test_db.py
+++ b/shared/tests/test_db.py
@@ -335,3 +335,82 @@ class TestTransactionContextManager:
mock_session.rollback.assert_awaited_once()
mock_session.commit.assert_not_awaited()
+
+
+class TestInsertPortfolioSnapshot:
+ @pytest.mark.asyncio
+ async def test_insert_portfolio_snapshot_uses_add_and_commit(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.add = MagicMock()
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ await db.insert_portfolio_snapshot(
+ total_value=Decimal("10000"),
+ realized_pnl=Decimal("0"),
+ unrealized_pnl=Decimal("500"),
+ )
+
+ mock_session.add.assert_called_once()
+ mock_session.commit.assert_awaited_once()
+
+ @pytest.mark.asyncio
+ async def test_insert_portfolio_snapshot_rollback_on_error(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.add = MagicMock()
+ mock_session.commit = AsyncMock(side_effect=RuntimeError("db error"))
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ with pytest.raises(RuntimeError, match="db error"):
+ await db.insert_portfolio_snapshot(
+ total_value=Decimal("10000"),
+ realized_pnl=Decimal("0"),
+ unrealized_pnl=Decimal("500"),
+ )
+
+ mock_session.rollback.assert_awaited_once()
+
+
+class TestGetPortfolioSnapshots:
+ @pytest.mark.asyncio
+ async def test_get_portfolio_snapshots_returns_list_of_dicts(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_row = MagicMock()
+ mock_row.total_value = Decimal("10000")
+ mock_row.realized_pnl = Decimal("0")
+ mock_row.unrealized_pnl = Decimal("500")
+ mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.all.return_value = [mock_row]
+
+ mock_session = AsyncMock()
+ mock_session.execute = AsyncMock(return_value=mock_result)
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ result = await db.get_portfolio_snapshots(days=30)
+
+ assert isinstance(result, list)
+ assert len(result) == 1
+ assert result[0]["total_value"] == Decimal("10000")
+ assert result[0]["unrealized_pnl"] == Decimal("500")
+ mock_session.execute.assert_awaited_once()