summaryrefslogtreecommitdiff
path: root/shared/src
diff options
context:
space:
mode:
Diffstat (limited to 'shared/src')
-rw-r--r--shared/src/shared/db.py47
1 files changed, 45 insertions, 2 deletions
diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py
index 515ba2c..901e293 100644
--- a/shared/src/shared/db.py
+++ b/shared/src/shared/db.py
@@ -1,14 +1,15 @@
"""Database layer using SQLAlchemy 2.0 async ORM for the trading platform."""
from contextlib import asynccontextmanager
-from datetime import datetime
+from datetime import datetime, timedelta, timezone
+from decimal import Decimal
from typing import Optional
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from shared.models import Candle, Signal, Order, OrderStatus
-from shared.sa_models import Base, CandleRow, SignalRow, OrderRow
+from shared.sa_models import Base, CandleRow, SignalRow, OrderRow, PortfolioSnapshotRow
class Database:
@@ -152,3 +153,45 @@ class Database:
await session.rollback()
raise
return [dict(row._mapping) for row in rows]
+
+ async def insert_portfolio_snapshot(
+ self,
+ total_value: Decimal,
+ realized_pnl: Decimal,
+ unrealized_pnl: Decimal,
+ ) -> None:
+ """Insert a portfolio snapshot."""
+ async with self.get_session() as session:
+ try:
+ row = PortfolioSnapshotRow(
+ total_value=total_value,
+ realized_pnl=realized_pnl,
+ unrealized_pnl=unrealized_pnl,
+ snapshot_at=datetime.now(timezone.utc),
+ )
+ session.add(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
+
+ async def get_portfolio_snapshots(self, days: int = 30) -> list[dict]:
+ """Retrieve recent portfolio snapshots."""
+ async with self.get_session() as session:
+ since = datetime.now(timezone.utc) - timedelta(days=days)
+ stmt = (
+ select(PortfolioSnapshotRow)
+ .where(PortfolioSnapshotRow.snapshot_at >= since)
+ .order_by(PortfolioSnapshotRow.snapshot_at.desc())
+ )
+ result = await session.execute(stmt)
+ rows = result.scalars().all()
+ return [
+ {
+ "total_value": r.total_value,
+ "realized_pnl": r.realized_pnl,
+ "unrealized_pnl": r.unrealized_pnl,
+ "snapshot_at": r.snapshot_at,
+ }
+ for r in rows
+ ]