summaryrefslogtreecommitdiff
path: root/shared/src
diff options
context:
space:
mode:
Diffstat (limited to 'shared/src')
-rw-r--r--shared/src/shared/db.py235
1 files changed, 87 insertions, 148 deletions
diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py
index 6bddd7c..95e487e 100644
--- a/shared/src/shared/db.py
+++ b/shared/src/shared/db.py
@@ -1,159 +1,95 @@
-"""Database layer using asyncpg for the trading platform."""
-from datetime import datetime, timezone
+"""Database layer using SQLAlchemy 2.0 async ORM for the trading platform."""
+from datetime import datetime
from typing import Optional
-import asyncpg
+from sqlalchemy import select, update
+from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from shared.models import Candle, Signal, Order, OrderStatus
-
-
-_INIT_SQL = """
-CREATE TABLE IF NOT EXISTS candles (
- symbol TEXT NOT NULL,
- timeframe TEXT NOT NULL,
- open_time TIMESTAMPTZ NOT NULL,
- open NUMERIC NOT NULL,
- high NUMERIC NOT NULL,
- low NUMERIC NOT NULL,
- close NUMERIC NOT NULL,
- volume NUMERIC NOT NULL,
- PRIMARY KEY (symbol, timeframe, open_time)
-);
-
-CREATE TABLE IF NOT EXISTS signals (
- id TEXT PRIMARY KEY,
- strategy TEXT NOT NULL,
- symbol TEXT NOT NULL,
- side TEXT NOT NULL,
- price NUMERIC NOT NULL,
- quantity NUMERIC NOT NULL,
- reason TEXT,
- created_at TIMESTAMPTZ NOT NULL
-);
-
-CREATE TABLE IF NOT EXISTS orders (
- id TEXT PRIMARY KEY,
- signal_id TEXT REFERENCES signals(id),
- symbol TEXT NOT NULL,
- side TEXT NOT NULL,
- type TEXT NOT NULL,
- price NUMERIC NOT NULL,
- quantity NUMERIC NOT NULL,
- status TEXT NOT NULL DEFAULT 'PENDING',
- created_at TIMESTAMPTZ NOT NULL,
- filled_at TIMESTAMPTZ
-);
-
-CREATE TABLE IF NOT EXISTS trades (
- id TEXT PRIMARY KEY,
- order_id TEXT REFERENCES orders(id),
- symbol TEXT NOT NULL,
- side TEXT NOT NULL,
- price NUMERIC NOT NULL,
- quantity NUMERIC NOT NULL,
- fee NUMERIC NOT NULL DEFAULT 0,
- traded_at TIMESTAMPTZ NOT NULL
-);
-
-CREATE TABLE IF NOT EXISTS positions (
- symbol TEXT PRIMARY KEY,
- quantity NUMERIC NOT NULL,
- avg_entry_price NUMERIC NOT NULL,
- current_price NUMERIC NOT NULL,
- updated_at TIMESTAMPTZ NOT NULL
-);
-
-CREATE TABLE IF NOT EXISTS portfolio_snapshots (
- id SERIAL PRIMARY KEY,
- total_value NUMERIC NOT NULL,
- realized_pnl NUMERIC NOT NULL,
- unrealized_pnl NUMERIC NOT NULL,
- snapshot_at TIMESTAMPTZ NOT NULL
-);
-"""
+from shared.sa_models import Base, CandleRow, SignalRow, OrderRow
class Database:
- """Async database access layer backed by asyncpg connection pool."""
+ """Async database access layer backed by SQLAlchemy async engine."""
def __init__(self, database_url: str) -> None:
+ # Auto-convert postgresql:// to postgresql+asyncpg://
+ if database_url.startswith("postgresql://"):
+ database_url = database_url.replace("postgresql://", "postgresql+asyncpg://", 1)
self._database_url = database_url
- self._pool: Optional[asyncpg.Pool] = None
+ self._engine = None
+ self._session_factory = None
async def connect(self) -> None:
- """Create the asyncpg connection pool."""
- self._pool = await asyncpg.create_pool(self._database_url)
+ """Create the async engine, session factory, and all tables."""
+ self._engine = create_async_engine(self._database_url)
+ self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)
+ async with self._engine.begin() as conn:
+ await conn.run_sync(Base.metadata.create_all)
+
+ async def init_tables(self) -> None:
+ """Alias for connect() for backward compatibility."""
+ await self.connect()
async def close(self) -> None:
- """Close the asyncpg connection pool."""
- if self._pool:
- await self._pool.close()
- self._pool = None
+ """Dispose of the async engine."""
+ if self._engine:
+ await self._engine.dispose()
+ self._engine = None
- async def init_tables(self) -> None:
- """Create all tables if they do not exist."""
- async with self._pool as conn:
- await conn.execute(_INIT_SQL)
+ def get_session(self) -> AsyncSession:
+ """Return a new async session from the factory."""
+ return self._session_factory()
async def insert_candle(self, candle: Candle) -> None:
- """Insert a candle row, ignoring duplicates."""
- sql = """
- INSERT INTO candles (symbol, timeframe, open_time, open, high, low, close, volume)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
- ON CONFLICT DO NOTHING
- """
- async with self._pool as conn:
- await conn.execute(
- sql,
- candle.symbol,
- candle.timeframe,
- candle.open_time,
- candle.open,
- candle.high,
- candle.low,
- candle.close,
- candle.volume,
- )
+ """Upsert a candle row using session.merge."""
+ row = CandleRow(
+ symbol=candle.symbol,
+ timeframe=candle.timeframe,
+ open_time=candle.open_time,
+ open=candle.open,
+ high=candle.high,
+ low=candle.low,
+ close=candle.close,
+ volume=candle.volume,
+ )
+ async with self._session_factory() as session:
+ session.merge(row)
+ await session.commit()
async def insert_signal(self, signal: Signal) -> None:
"""Insert a signal row."""
- sql = """
- INSERT INTO signals (id, strategy, symbol, side, price, quantity, reason, created_at)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
- """
- async with self._pool as conn:
- await conn.execute(
- sql,
- signal.id,
- signal.strategy,
- signal.symbol,
- signal.side.value,
- signal.price,
- signal.quantity,
- signal.reason,
- signal.created_at,
- )
+ row = SignalRow(
+ id=signal.id,
+ strategy=signal.strategy,
+ symbol=signal.symbol,
+ side=signal.side.value,
+ price=signal.price,
+ quantity=signal.quantity,
+ reason=signal.reason,
+ created_at=signal.created_at,
+ )
+ async with self._session_factory() as session:
+ session.add(row)
+ await session.commit()
async def insert_order(self, order: Order) -> None:
"""Insert an order row."""
- sql = """
- INSERT INTO orders (id, signal_id, symbol, side, type, price, quantity, status, created_at, filled_at)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
- """
- async with self._pool as conn:
- await conn.execute(
- sql,
- order.id,
- order.signal_id,
- order.symbol,
- order.side.value,
- order.type.value,
- order.price,
- order.quantity,
- order.status.value,
- order.created_at,
- order.filled_at,
- )
+ row = OrderRow(
+ id=order.id,
+ signal_id=order.signal_id,
+ symbol=order.symbol,
+ side=order.side.value,
+ type=order.type.value,
+ price=order.price,
+ quantity=order.quantity,
+ status=order.status.value,
+ created_at=order.created_at,
+ filled_at=order.filled_at,
+ )
+ async with self._session_factory() as session:
+ session.add(row)
+ await session.commit()
async def update_order_status(
self,
@@ -162,23 +98,26 @@ class Database:
filled_at: Optional[datetime] = None,
) -> None:
"""Update the status (and optionally filled_at) of an order."""
- sql = """
- UPDATE orders SET status = $2, filled_at = $3 WHERE id = $1
- """
- async with self._pool as conn:
- await conn.execute(sql, order_id, status.value, filled_at)
+ stmt = (
+ update(OrderRow)
+ .where(OrderRow.id == order_id)
+ .values(status=status.value, filled_at=filled_at)
+ )
+ async with self._session_factory() as session:
+ await session.execute(stmt)
+ await session.commit()
async def get_candles(
self, symbol: str, timeframe: str, limit: int = 500
) -> list[dict]:
"""Retrieve candles ordered by open_time descending."""
- sql = """
- SELECT symbol, timeframe, open_time, open, high, low, close, volume
- FROM candles
- WHERE symbol = $1 AND timeframe = $2
- ORDER BY open_time DESC
- LIMIT $3
- """
- async with self._pool as conn:
- rows = await conn.fetch(sql, symbol, timeframe, limit)
- return [dict(row) for row in rows]
+ stmt = (
+ select(CandleRow)
+ .where(CandleRow.symbol == symbol, CandleRow.timeframe == timeframe)
+ .order_by(CandleRow.open_time.desc())
+ .limit(limit)
+ )
+ async with self._session_factory() as session:
+ result = await session.execute(stmt)
+ rows = result.fetchall()
+ return [dict(row._mapping) for row in rows]