From c1d53dbc173f87fe23e179f21c9a713df6484dae Mon Sep 17 00:00:00 2001 From: TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> Date: Wed, 1 Apr 2026 16:04:45 +0900 Subject: feat: rewrite database layer from asyncpg to SQLAlchemy 2.0 async ORM Replace raw asyncpg SQL with SQLAlchemy async engine, async_sessionmaker, and ORM operations. Uses session.merge for candle upserts, session.add for signal/order inserts, update() for status changes, select() for queries. Auto-converts postgresql:// URLs to postgresql+asyncpg://. Keeps init_tables() as backward-compatible alias for connect(). --- shared/src/shared/db.py | 235 ++++++++++++++++++------------------------------ 1 file changed, 87 insertions(+), 148 deletions(-) (limited to 'shared/src') diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py index 6bddd7c..95e487e 100644 --- a/shared/src/shared/db.py +++ b/shared/src/shared/db.py @@ -1,159 +1,95 @@ -"""Database layer using asyncpg for the trading platform.""" -from datetime import datetime, timezone +"""Database layer using SQLAlchemy 2.0 async ORM for the trading platform.""" +from datetime import datetime from typing import Optional -import asyncpg +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from shared.models import Candle, Signal, Order, OrderStatus - - -_INIT_SQL = """ -CREATE TABLE IF NOT EXISTS candles ( - symbol TEXT NOT NULL, - timeframe TEXT NOT NULL, - open_time TIMESTAMPTZ NOT NULL, - open NUMERIC NOT NULL, - high NUMERIC NOT NULL, - low NUMERIC NOT NULL, - close NUMERIC NOT NULL, - volume NUMERIC NOT NULL, - PRIMARY KEY (symbol, timeframe, open_time) -); - -CREATE TABLE IF NOT EXISTS signals ( - id TEXT PRIMARY KEY, - strategy TEXT NOT NULL, - symbol TEXT NOT NULL, - side TEXT NOT NULL, - price NUMERIC NOT NULL, - quantity NUMERIC NOT NULL, - reason TEXT, - created_at TIMESTAMPTZ NOT NULL -); - -CREATE TABLE IF NOT EXISTS orders ( - id TEXT PRIMARY KEY, - signal_id TEXT REFERENCES signals(id), - symbol TEXT NOT NULL, - side TEXT NOT NULL, - type TEXT NOT NULL, - price NUMERIC NOT NULL, - quantity NUMERIC NOT NULL, - status TEXT NOT NULL DEFAULT 'PENDING', - created_at TIMESTAMPTZ NOT NULL, - filled_at TIMESTAMPTZ -); - -CREATE TABLE IF NOT EXISTS trades ( - id TEXT PRIMARY KEY, - order_id TEXT REFERENCES orders(id), - symbol TEXT NOT NULL, - side TEXT NOT NULL, - price NUMERIC NOT NULL, - quantity NUMERIC NOT NULL, - fee NUMERIC NOT NULL DEFAULT 0, - traded_at TIMESTAMPTZ NOT NULL -); - -CREATE TABLE IF NOT EXISTS positions ( - symbol TEXT PRIMARY KEY, - quantity NUMERIC NOT NULL, - avg_entry_price NUMERIC NOT NULL, - current_price NUMERIC NOT NULL, - updated_at TIMESTAMPTZ NOT NULL -); - -CREATE TABLE IF NOT EXISTS portfolio_snapshots ( - id SERIAL PRIMARY KEY, - total_value NUMERIC NOT NULL, - realized_pnl NUMERIC NOT NULL, - unrealized_pnl NUMERIC NOT NULL, - snapshot_at TIMESTAMPTZ NOT NULL -); -""" +from shared.sa_models import Base, CandleRow, SignalRow, OrderRow class Database: - """Async database access layer backed by asyncpg connection pool.""" + """Async database access layer backed by SQLAlchemy async engine.""" def __init__(self, database_url: str) -> None: + # Auto-convert postgresql:// to postgresql+asyncpg:// + if database_url.startswith("postgresql://"): + database_url = database_url.replace("postgresql://", "postgresql+asyncpg://", 1) self._database_url = database_url - self._pool: Optional[asyncpg.Pool] = None + self._engine = None + self._session_factory = None async def connect(self) -> None: - """Create the asyncpg connection pool.""" - self._pool = await asyncpg.create_pool(self._database_url) + """Create the async engine, session factory, and all tables.""" + self._engine = create_async_engine(self._database_url) + self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False) + async with self._engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + async def init_tables(self) -> None: + """Alias for connect() for backward compatibility.""" + await self.connect() async def close(self) -> None: - """Close the asyncpg connection pool.""" - if self._pool: - await self._pool.close() - self._pool = None + """Dispose of the async engine.""" + if self._engine: + await self._engine.dispose() + self._engine = None - async def init_tables(self) -> None: - """Create all tables if they do not exist.""" - async with self._pool as conn: - await conn.execute(_INIT_SQL) + def get_session(self) -> AsyncSession: + """Return a new async session from the factory.""" + return self._session_factory() async def insert_candle(self, candle: Candle) -> None: - """Insert a candle row, ignoring duplicates.""" - sql = """ - INSERT INTO candles (symbol, timeframe, open_time, open, high, low, close, volume) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - ON CONFLICT DO NOTHING - """ - async with self._pool as conn: - await conn.execute( - sql, - candle.symbol, - candle.timeframe, - candle.open_time, - candle.open, - candle.high, - candle.low, - candle.close, - candle.volume, - ) + """Upsert a candle row using session.merge.""" + row = CandleRow( + symbol=candle.symbol, + timeframe=candle.timeframe, + open_time=candle.open_time, + open=candle.open, + high=candle.high, + low=candle.low, + close=candle.close, + volume=candle.volume, + ) + async with self._session_factory() as session: + session.merge(row) + await session.commit() async def insert_signal(self, signal: Signal) -> None: """Insert a signal row.""" - sql = """ - INSERT INTO signals (id, strategy, symbol, side, price, quantity, reason, created_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - """ - async with self._pool as conn: - await conn.execute( - sql, - signal.id, - signal.strategy, - signal.symbol, - signal.side.value, - signal.price, - signal.quantity, - signal.reason, - signal.created_at, - ) + row = SignalRow( + id=signal.id, + strategy=signal.strategy, + symbol=signal.symbol, + side=signal.side.value, + price=signal.price, + quantity=signal.quantity, + reason=signal.reason, + created_at=signal.created_at, + ) + async with self._session_factory() as session: + session.add(row) + await session.commit() async def insert_order(self, order: Order) -> None: """Insert an order row.""" - sql = """ - INSERT INTO orders (id, signal_id, symbol, side, type, price, quantity, status, created_at, filled_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - """ - async with self._pool as conn: - await conn.execute( - sql, - order.id, - order.signal_id, - order.symbol, - order.side.value, - order.type.value, - order.price, - order.quantity, - order.status.value, - order.created_at, - order.filled_at, - ) + row = OrderRow( + id=order.id, + signal_id=order.signal_id, + symbol=order.symbol, + side=order.side.value, + type=order.type.value, + price=order.price, + quantity=order.quantity, + status=order.status.value, + created_at=order.created_at, + filled_at=order.filled_at, + ) + async with self._session_factory() as session: + session.add(row) + await session.commit() async def update_order_status( self, @@ -162,23 +98,26 @@ class Database: filled_at: Optional[datetime] = None, ) -> None: """Update the status (and optionally filled_at) of an order.""" - sql = """ - UPDATE orders SET status = $2, filled_at = $3 WHERE id = $1 - """ - async with self._pool as conn: - await conn.execute(sql, order_id, status.value, filled_at) + stmt = ( + update(OrderRow) + .where(OrderRow.id == order_id) + .values(status=status.value, filled_at=filled_at) + ) + async with self._session_factory() as session: + await session.execute(stmt) + await session.commit() async def get_candles( self, symbol: str, timeframe: str, limit: int = 500 ) -> list[dict]: """Retrieve candles ordered by open_time descending.""" - sql = """ - SELECT symbol, timeframe, open_time, open, high, low, close, volume - FROM candles - WHERE symbol = $1 AND timeframe = $2 - ORDER BY open_time DESC - LIMIT $3 - """ - async with self._pool as conn: - rows = await conn.fetch(sql, symbol, timeframe, limit) - return [dict(row) for row in rows] + stmt = ( + select(CandleRow) + .where(CandleRow.symbol == symbol, CandleRow.timeframe == timeframe) + .order_by(CandleRow.open_time.desc()) + .limit(limit) + ) + async with self._session_factory() as session: + result = await session.execute(stmt) + rows = result.fetchall() + return [dict(row._mapping) for row in rows] -- cgit v1.2.3