summaryrefslogtreecommitdiff
path: root/shared
diff options
context:
space:
mode:
authorTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-01 17:15:23 +0900
committerTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-01 17:15:23 +0900
commit8c11cae987292421840658585c0667706790c8ca (patch)
tree5ab5571bf1bdc36e3ae3a09c0c29f49ca74c1135 /shared
parenta6bf0057d32df7ed0a1d6ec6d19daf74a0de5c0f (diff)
feat(portfolio): add periodic portfolio snapshots and daily Telegram summary
Diffstat (limited to 'shared')
-rw-r--r--shared/src/shared/db.py47
-rw-r--r--shared/tests/test_db.py79
2 files changed, 124 insertions, 2 deletions
diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py
index 515ba2c..901e293 100644
--- a/shared/src/shared/db.py
+++ b/shared/src/shared/db.py
@@ -1,14 +1,15 @@
"""Database layer using SQLAlchemy 2.0 async ORM for the trading platform."""
from contextlib import asynccontextmanager
-from datetime import datetime
+from datetime import datetime, timedelta, timezone
+from decimal import Decimal
from typing import Optional
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from shared.models import Candle, Signal, Order, OrderStatus
-from shared.sa_models import Base, CandleRow, SignalRow, OrderRow
+from shared.sa_models import Base, CandleRow, SignalRow, OrderRow, PortfolioSnapshotRow
class Database:
@@ -152,3 +153,45 @@ class Database:
await session.rollback()
raise
return [dict(row._mapping) for row in rows]
+
+ async def insert_portfolio_snapshot(
+ self,
+ total_value: Decimal,
+ realized_pnl: Decimal,
+ unrealized_pnl: Decimal,
+ ) -> None:
+ """Insert a portfolio snapshot."""
+ async with self.get_session() as session:
+ try:
+ row = PortfolioSnapshotRow(
+ total_value=total_value,
+ realized_pnl=realized_pnl,
+ unrealized_pnl=unrealized_pnl,
+ snapshot_at=datetime.now(timezone.utc),
+ )
+ session.add(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
+
+ async def get_portfolio_snapshots(self, days: int = 30) -> list[dict]:
+ """Retrieve recent portfolio snapshots."""
+ async with self.get_session() as session:
+ since = datetime.now(timezone.utc) - timedelta(days=days)
+ stmt = (
+ select(PortfolioSnapshotRow)
+ .where(PortfolioSnapshotRow.snapshot_at >= since)
+ .order_by(PortfolioSnapshotRow.snapshot_at.desc())
+ )
+ result = await session.execute(stmt)
+ rows = result.scalars().all()
+ return [
+ {
+ "total_value": r.total_value,
+ "realized_pnl": r.realized_pnl,
+ "unrealized_pnl": r.unrealized_pnl,
+ "snapshot_at": r.snapshot_at,
+ }
+ for r in rows
+ ]
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
index 04efe9e..f4cabfd 100644
--- a/shared/tests/test_db.py
+++ b/shared/tests/test_db.py
@@ -335,3 +335,82 @@ class TestTransactionContextManager:
mock_session.rollback.assert_awaited_once()
mock_session.commit.assert_not_awaited()
+
+
+class TestInsertPortfolioSnapshot:
+ @pytest.mark.asyncio
+ async def test_insert_portfolio_snapshot_uses_add_and_commit(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.add = MagicMock()
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ await db.insert_portfolio_snapshot(
+ total_value=Decimal("10000"),
+ realized_pnl=Decimal("0"),
+ unrealized_pnl=Decimal("500"),
+ )
+
+ mock_session.add.assert_called_once()
+ mock_session.commit.assert_awaited_once()
+
+ @pytest.mark.asyncio
+ async def test_insert_portfolio_snapshot_rollback_on_error(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.add = MagicMock()
+ mock_session.commit = AsyncMock(side_effect=RuntimeError("db error"))
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ with pytest.raises(RuntimeError, match="db error"):
+ await db.insert_portfolio_snapshot(
+ total_value=Decimal("10000"),
+ realized_pnl=Decimal("0"),
+ unrealized_pnl=Decimal("500"),
+ )
+
+ mock_session.rollback.assert_awaited_once()
+
+
+class TestGetPortfolioSnapshots:
+ @pytest.mark.asyncio
+ async def test_get_portfolio_snapshots_returns_list_of_dicts(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_row = MagicMock()
+ mock_row.total_value = Decimal("10000")
+ mock_row.realized_pnl = Decimal("0")
+ mock_row.unrealized_pnl = Decimal("500")
+ mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.all.return_value = [mock_row]
+
+ mock_session = AsyncMock()
+ mock_session.execute = AsyncMock(return_value=mock_result)
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ result = await db.get_portfolio_snapshots(days=30)
+
+ assert isinstance(result, list)
+ assert len(result) == 1
+ assert result[0]["total_value"] == Decimal("10000")
+ assert result[0]["unrealized_pnl"] == Decimal("500")
+ mock_session.execute.assert_awaited_once()