From c1d53dbc173f87fe23e179f21c9a713df6484dae Mon Sep 17 00:00:00 2001 From: TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> Date: Wed, 1 Apr 2026 16:04:45 +0900 Subject: feat: rewrite database layer from asyncpg to SQLAlchemy 2.0 async ORM Replace raw asyncpg SQL with SQLAlchemy async engine, async_sessionmaker, and ORM operations. Uses session.merge for candle upserts, session.add for signal/order inserts, update() for status changes, select() for queries. Auto-converts postgresql:// URLs to postgresql+asyncpg://. Keeps init_tables() as backward-compatible alias for connect(). --- shared/src/shared/db.py | 235 ++++++++++++++++++----------------------------- shared/tests/test_db.py | 236 +++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 290 insertions(+), 181 deletions(-) diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py index 6bddd7c..95e487e 100644 --- a/shared/src/shared/db.py +++ b/shared/src/shared/db.py @@ -1,159 +1,95 @@ -"""Database layer using asyncpg for the trading platform.""" -from datetime import datetime, timezone +"""Database layer using SQLAlchemy 2.0 async ORM for the trading platform.""" +from datetime import datetime from typing import Optional -import asyncpg +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from shared.models import Candle, Signal, Order, OrderStatus - - -_INIT_SQL = """ -CREATE TABLE IF NOT EXISTS candles ( - symbol TEXT NOT NULL, - timeframe TEXT NOT NULL, - open_time TIMESTAMPTZ NOT NULL, - open NUMERIC NOT NULL, - high NUMERIC NOT NULL, - low NUMERIC NOT NULL, - close NUMERIC NOT NULL, - volume NUMERIC NOT NULL, - PRIMARY KEY (symbol, timeframe, open_time) -); - -CREATE TABLE IF NOT EXISTS signals ( - id TEXT PRIMARY KEY, - strategy TEXT NOT NULL, - symbol TEXT NOT NULL, - side TEXT NOT NULL, - price NUMERIC NOT NULL, - quantity NUMERIC NOT NULL, - reason TEXT, - created_at TIMESTAMPTZ NOT NULL -); - 
-CREATE TABLE IF NOT EXISTS orders ( - id TEXT PRIMARY KEY, - signal_id TEXT REFERENCES signals(id), - symbol TEXT NOT NULL, - side TEXT NOT NULL, - type TEXT NOT NULL, - price NUMERIC NOT NULL, - quantity NUMERIC NOT NULL, - status TEXT NOT NULL DEFAULT 'PENDING', - created_at TIMESTAMPTZ NOT NULL, - filled_at TIMESTAMPTZ -); - -CREATE TABLE IF NOT EXISTS trades ( - id TEXT PRIMARY KEY, - order_id TEXT REFERENCES orders(id), - symbol TEXT NOT NULL, - side TEXT NOT NULL, - price NUMERIC NOT NULL, - quantity NUMERIC NOT NULL, - fee NUMERIC NOT NULL DEFAULT 0, - traded_at TIMESTAMPTZ NOT NULL -); - -CREATE TABLE IF NOT EXISTS positions ( - symbol TEXT PRIMARY KEY, - quantity NUMERIC NOT NULL, - avg_entry_price NUMERIC NOT NULL, - current_price NUMERIC NOT NULL, - updated_at TIMESTAMPTZ NOT NULL -); - -CREATE TABLE IF NOT EXISTS portfolio_snapshots ( - id SERIAL PRIMARY KEY, - total_value NUMERIC NOT NULL, - realized_pnl NUMERIC NOT NULL, - unrealized_pnl NUMERIC NOT NULL, - snapshot_at TIMESTAMPTZ NOT NULL -); -""" +from shared.sa_models import Base, CandleRow, SignalRow, OrderRow class Database: - """Async database access layer backed by asyncpg connection pool.""" + """Async database access layer backed by SQLAlchemy async engine.""" def __init__(self, database_url: str) -> None: + # Auto-convert postgresql:// to postgresql+asyncpg:// + if database_url.startswith("postgresql://"): + database_url = database_url.replace("postgresql://", "postgresql+asyncpg://", 1) self._database_url = database_url - self._pool: Optional[asyncpg.Pool] = None + self._engine = None + self._session_factory = None async def connect(self) -> None: - """Create the asyncpg connection pool.""" - self._pool = await asyncpg.create_pool(self._database_url) + """Create the async engine, session factory, and all tables.""" + self._engine = create_async_engine(self._database_url) + self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False) + async with self._engine.begin() as 
conn: + await conn.run_sync(Base.metadata.create_all) + + async def init_tables(self) -> None: + """Alias for connect() for backward compatibility.""" + await self.connect() async def close(self) -> None: - """Close the asyncpg connection pool.""" - if self._pool: - await self._pool.close() - self._pool = None + """Dispose of the async engine.""" + if self._engine: + await self._engine.dispose() + self._engine = None - async def init_tables(self) -> None: - """Create all tables if they do not exist.""" - async with self._pool as conn: - await conn.execute(_INIT_SQL) + def get_session(self) -> AsyncSession: + """Return a new async session from the factory.""" + return self._session_factory() async def insert_candle(self, candle: Candle) -> None: - """Insert a candle row, ignoring duplicates.""" - sql = """ - INSERT INTO candles (symbol, timeframe, open_time, open, high, low, close, volume) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - ON CONFLICT DO NOTHING - """ - async with self._pool as conn: - await conn.execute( - sql, - candle.symbol, - candle.timeframe, - candle.open_time, - candle.open, - candle.high, - candle.low, - candle.close, - candle.volume, - ) + """Upsert a candle row using session.merge.""" + row = CandleRow( + symbol=candle.symbol, + timeframe=candle.timeframe, + open_time=candle.open_time, + open=candle.open, + high=candle.high, + low=candle.low, + close=candle.close, + volume=candle.volume, + ) + async with self._session_factory() as session: + await session.merge(row) + await session.commit() async def insert_signal(self, signal: Signal) -> None: """Insert a signal row.""" - sql = """ - INSERT INTO signals (id, strategy, symbol, side, price, quantity, reason, created_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - """ - async with self._pool as conn: - await conn.execute( - sql, - signal.id, - signal.strategy, - signal.symbol, - signal.side.value, - signal.price, - signal.quantity, - signal.reason, - signal.created_at, - ) + row = SignalRow( + 
id=signal.id, + strategy=signal.strategy, + symbol=signal.symbol, + side=signal.side.value, + price=signal.price, + quantity=signal.quantity, + reason=signal.reason, + created_at=signal.created_at, + ) + async with self._session_factory() as session: + session.add(row) + await session.commit() async def insert_order(self, order: Order) -> None: """Insert an order row.""" - sql = """ - INSERT INTO orders (id, signal_id, symbol, side, type, price, quantity, status, created_at, filled_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - """ - async with self._pool as conn: - await conn.execute( - sql, - order.id, - order.signal_id, - order.symbol, - order.side.value, - order.type.value, - order.price, - order.quantity, - order.status.value, - order.created_at, - order.filled_at, - ) + row = OrderRow( + id=order.id, + signal_id=order.signal_id, + symbol=order.symbol, + side=order.side.value, + type=order.type.value, + price=order.price, + quantity=order.quantity, + status=order.status.value, + created_at=order.created_at, + filled_at=order.filled_at, + ) + async with self._session_factory() as session: + session.add(row) + await session.commit() async def update_order_status( self, @@ -162,23 +98,26 @@ class Database: filled_at: Optional[datetime] = None, ) -> None: """Update the status (and optionally filled_at) of an order.""" - sql = """ - UPDATE orders SET status = $2, filled_at = $3 WHERE id = $1 - """ - async with self._pool as conn: - await conn.execute(sql, order_id, status.value, filled_at) + stmt = ( + update(OrderRow) + .where(OrderRow.id == order_id) + .values(status=status.value, filled_at=filled_at) + ) + async with self._session_factory() as session: + await session.execute(stmt) + await session.commit() async def get_candles( self, symbol: str, timeframe: str, limit: int = 500 ) -> list[dict]: """Retrieve candles ordered by open_time descending.""" - sql = """ - SELECT symbol, timeframe, open_time, open, high, low, close, volume - FROM candles - 
WHERE symbol = $1 AND timeframe = $2 - ORDER BY open_time DESC - LIMIT $3 - """ - async with self._pool as conn: - rows = await conn.fetch(sql, symbol, timeframe, limit) - return [dict(row) for row in rows] + stmt = ( + select(CandleRow.__table__) + .where(CandleRow.symbol == symbol, CandleRow.timeframe == timeframe) + .order_by(CandleRow.open_time.desc()) + .limit(limit) + ) + async with self._session_factory() as session: + result = await session.execute(stmt) + rows = result.fetchall() + return [dict(row._mapping) for row in rows] diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index c31e487..45d5dcd 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -1,4 +1,4 @@ -"""Tests for the database layer.""" +"""Tests for the SQLAlchemy async database layer.""" import pytest from decimal import Decimal from datetime import datetime, timezone @@ -19,52 +19,222 @@ def make_candle(): ) -@pytest.mark.asyncio -async def test_db_init_sql_creates_tables(): - """Verify that init_tables SQL references all required table names.""" - with patch("asyncpg.create_pool", new_callable=AsyncMock) as mock_pool: - mock_conn = AsyncMock() - mock_pool.return_value.__aenter__ = AsyncMock(return_value=mock_conn) - mock_pool.return_value.__aexit__ = AsyncMock(return_value=False) +def make_signal(): + from shared.models import Signal, OrderSide + return Signal( + id="sig-1", + strategy="ma_cross", + symbol="BTCUSDT", + side=OrderSide.BUY, + price=Decimal("50000"), + quantity=Decimal("0.1"), + reason="Golden cross", + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ) + + +def make_order(): + from shared.models import Order, OrderSide, OrderType, OrderStatus + return Order( + id="ord-1", + signal_id="sig-1", + symbol="BTCUSDT", + side=OrderSide.BUY, + type=OrderType.LIMIT, + price=Decimal("50000"), + quantity=Decimal("0.1"), + status=OrderStatus.PENDING, + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ) + - # Capture the SQL that gets executed - 
executed_sqls = [] +class TestDatabaseConstructor: + def test_stores_url(self): + from shared.db import Database + db = Database("postgresql://user:pass@localhost/db") + assert db._database_url == "postgresql+asyncpg://user:pass@localhost/db" - async def capture_execute(sql, *args, **kwargs): - executed_sqls.append(sql) + def test_converts_url_prefix(self): + from shared.db import Database + db = Database("postgresql://host/db") + assert db._database_url.startswith("postgresql+asyncpg://") - mock_conn.execute = capture_execute + def test_keeps_asyncpg_prefix(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") + assert db._database_url == "postgresql+asyncpg://host/db" + def test_get_session_exists(self): from shared.db import Database - db = Database("postgresql://trading:trading@localhost:5432/trading") - db._pool = mock_pool.return_value - await db.init_tables() + db = Database("postgresql+asyncpg://host/db") + assert hasattr(db, "get_session") - combined_sql = " ".join(executed_sqls) - for table in ["candles", "signals", "orders", "trades", "positions", "portfolio_snapshots"]: - assert table in combined_sql, f"Table '{table}' not found in SQL" +class TestDatabaseConnect: + @pytest.mark.asyncio + async def test_connect_creates_engine_and_tables(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") -@pytest.mark.asyncio -async def test_db_insert_candle(): - """Verify that insert_candle executes INSERT INTO candles.""" - with patch("asyncpg.create_pool", new_callable=AsyncMock) as mock_pool: mock_conn = AsyncMock() - mock_pool.return_value.__aenter__ = AsyncMock(return_value=mock_conn) - mock_pool.return_value.__aexit__ = AsyncMock(return_value=False) + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", 
return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect() + mock_create.assert_called_once() + + @pytest.mark.asyncio + async def test_init_tables_is_alias_for_connect(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn - executed = [] + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() - async def capture_execute(sql, *args, **kwargs): - executed.append((sql, args)) + with patch("shared.db.create_async_engine", return_value=mock_engine): + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.init_tables() + # Should succeed without error (same as connect) + + +class TestDatabaseClose: + @pytest.mark.asyncio + async def test_close_disposes_engine(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") + mock_engine = AsyncMock() + db._engine = mock_engine + await db.close() + mock_engine.dispose.assert_awaited_once() - mock_conn.execute = capture_execute +class TestInsertCandle: + @pytest.mark.asyncio + async def test_insert_candle_uses_merge_and_commit(self): from shared.db import Database - db = Database("postgresql://trading:trading@localhost:5432/trading") - db._pool = mock_pool.return_value + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.merge = AsyncMock(return_value=None) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + candle = make_candle() await db.insert_candle(candle) - assert any("INSERT INTO candles" in sql for sql, _ in 
executed), \ - "Expected INSERT INTO candles" + mock_session.merge.assert_called_once() + mock_session.commit.assert_awaited_once() + + +class TestInsertSignal: + @pytest.mark.asyncio + async def test_insert_signal_uses_add_and_commit(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + signal = make_signal() + await db.insert_signal(signal) + + mock_session.add.assert_called_once() + mock_session.commit.assert_awaited_once() + + +class TestInsertOrder: + @pytest.mark.asyncio + async def test_insert_order_uses_add_and_commit(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + order = make_order() + await db.insert_order(order) + + mock_session.add.assert_called_once() + mock_session.commit.assert_awaited_once() + + +class TestUpdateOrderStatus: + @pytest.mark.asyncio + async def test_update_order_status_uses_execute_and_commit(self): + from shared.db import Database + from shared.models import OrderStatus + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + filled = datetime(2024, 1, 2, tzinfo=timezone.utc) + await db.update_order_status("ord-1", OrderStatus.FILLED, filled) + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_awaited_once() + + +class 
TestGetCandles: + @pytest.mark.asyncio + async def test_get_candles_returns_list_of_dicts(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") + + # Create a mock row that behaves like a SA result row + mock_row = MagicMock() + mock_row._mapping = { + "symbol": "BTCUSDT", + "timeframe": "1m", + "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc), + "open": Decimal("50000"), + "high": Decimal("51000"), + "low": Decimal("49500"), + "close": Decimal("50500"), + "volume": Decimal("100"), + } + + mock_result = MagicMock() + mock_result.fetchall.return_value = [mock_row] + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + result = await db.get_candles("BTCUSDT", "1m", 500) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["symbol"] == "BTCUSDT" + mock_session.execute.assert_awaited_once() -- cgit v1.2.3