diff options
| author | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 16:04:45 +0900 |
|---|---|---|
| committer | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 16:04:45 +0900 |
| commit | c1d53dbc173f87fe23e179f21c9a713df6484dae (patch) | |
| tree | 64ba8065e1e30b84b36b782d16c5998ffc9f9322 /shared/tests | |
| parent | 2b1db156c7ea7e0be543ab91813922b95eb043cb (diff) | |
feat: rewrite database layer from asyncpg to SQLAlchemy 2.0 async ORM
Replace raw asyncpg SQL with SQLAlchemy async engine, async_sessionmaker,
and ORM operations. Uses session.merge for candle upserts, session.add
for signal/order inserts, update() for status changes, select() for
queries. Auto-converts postgresql:// URLs to postgresql+asyncpg://.
Keeps init_tables() as backward-compatible alias for connect().
Diffstat (limited to 'shared/tests')
| -rw-r--r-- | shared/tests/test_db.py | 236 |
1 files changed, 203 insertions, 33 deletions
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index c31e487..45d5dcd 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -1,4 +1,4 @@ -"""Tests for the database layer.""" +"""Tests for the SQLAlchemy async database layer.""" import pytest from decimal import Decimal from datetime import datetime, timezone @@ -19,52 +19,222 @@ def make_candle(): ) -@pytest.mark.asyncio -async def test_db_init_sql_creates_tables(): - """Verify that init_tables SQL references all required table names.""" - with patch("asyncpg.create_pool", new_callable=AsyncMock) as mock_pool: - mock_conn = AsyncMock() - mock_pool.return_value.__aenter__ = AsyncMock(return_value=mock_conn) - mock_pool.return_value.__aexit__ = AsyncMock(return_value=False) +def make_signal(): + from shared.models import Signal, OrderSide + return Signal( + id="sig-1", + strategy="ma_cross", + symbol="BTCUSDT", + side=OrderSide.BUY, + price=Decimal("50000"), + quantity=Decimal("0.1"), + reason="Golden cross", + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ) + + +def make_order(): + from shared.models import Order, OrderSide, OrderType, OrderStatus + return Order( + id="ord-1", + signal_id="sig-1", + symbol="BTCUSDT", + side=OrderSide.BUY, + type=OrderType.LIMIT, + price=Decimal("50000"), + quantity=Decimal("0.1"), + status=OrderStatus.PENDING, + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ) + - # Capture the SQL that gets executed - executed_sqls = [] +class TestDatabaseConstructor: + def test_stores_url(self): + from shared.db import Database + db = Database("postgresql://user:pass@localhost/db") + assert db._database_url == "postgresql+asyncpg://user:pass@localhost/db" - async def capture_execute(sql, *args, **kwargs): - executed_sqls.append(sql) + def test_converts_url_prefix(self): + from shared.db import Database + db = Database("postgresql://host/db") + assert db._database_url.startswith("postgresql+asyncpg://") - mock_conn.execute = capture_execute + def test_keeps_asyncpg_prefix(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") + assert db._database_url == "postgresql+asyncpg://host/db" + def test_get_session_exists(self): from shared.db import Database - db = Database("postgresql://trading:trading@localhost:5432/trading") - db._pool = mock_pool.return_value - await db.init_tables() + db = Database("postgresql+asyncpg://host/db") + assert hasattr(db, "get_session") - combined_sql = " ".join(executed_sqls) - for table in ["candles", "signals", "orders", "trades", "positions", "portfolio_snapshots"]: - assert table in combined_sql, f"Table '{table}' not found in SQL" +class TestDatabaseConnect: + @pytest.mark.asyncio + async def test_connect_creates_engine_and_tables(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") -@pytest.mark.asyncio -async def test_db_insert_candle(): - """Verify that insert_candle executes INSERT INTO candles.""" - with patch("asyncpg.create_pool", new_callable=AsyncMock) as mock_pool: mock_conn = AsyncMock() - mock_pool.return_value.__aenter__ = AsyncMock(return_value=mock_conn) - mock_pool.return_value.__aexit__ = AsyncMock(return_value=False) + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect() + mock_create.assert_called_once() + + @pytest.mark.asyncio + async def test_init_tables_is_alias_for_connect(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn - executed = [] + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() - async def capture_execute(sql, *args, **kwargs): - executed.append((sql, args)) + with patch("shared.db.create_async_engine", return_value=mock_engine): + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.init_tables() + # Should succeed without error (same as connect) + + +class TestDatabaseClose: + @pytest.mark.asyncio + async def test_close_disposes_engine(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") + mock_engine = AsyncMock() + db._engine = mock_engine + await db.close() + mock_engine.dispose.assert_awaited_once() - mock_conn.execute = capture_execute +class TestInsertCandle: + @pytest.mark.asyncio + async def test_insert_candle_uses_merge_and_commit(self): from shared.db import Database - db = Database("postgresql://trading:trading@localhost:5432/trading") - db._pool = mock_pool.return_value + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.merge = MagicMock(return_value=None) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + candle = make_candle() await db.insert_candle(candle) - assert any("INSERT INTO candles" in sql for sql, _ in executed), \ - "Expected INSERT INTO candles" + mock_session.merge.assert_called_once() + mock_session.commit.assert_awaited_once() + + +class TestInsertSignal: + @pytest.mark.asyncio + async def test_insert_signal_uses_add_and_commit(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + signal = make_signal() + await db.insert_signal(signal) + + mock_session.add.assert_called_once() + mock_session.commit.assert_awaited_once() + + +class TestInsertOrder: + @pytest.mark.asyncio + async def test_insert_order_uses_add_and_commit(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + order = make_order() + await db.insert_order(order) + + mock_session.add.assert_called_once() + mock_session.commit.assert_awaited_once() + + +class TestUpdateOrderStatus: + @pytest.mark.asyncio + async def test_update_order_status_uses_execute_and_commit(self): + from shared.db import Database + from shared.models import OrderStatus + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + filled = datetime(2024, 1, 2, tzinfo=timezone.utc) + await db.update_order_status("ord-1", OrderStatus.FILLED, filled) + + mock_session.execute.assert_awaited_once() + mock_session.commit.assert_awaited_once() + + +class TestGetCandles: + @pytest.mark.asyncio + async def test_get_candles_returns_list_of_dicts(self): + from shared.db import Database + db = Database("postgresql+asyncpg://host/db") + + # Create a mock row that behaves like a SA result row + mock_row = MagicMock() + mock_row._mapping = { + "symbol": "BTCUSDT", + "timeframe": "1m", + "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc), + "open": Decimal("50000"), + "high": Decimal("51000"), + "low": Decimal("49500"), + "close": Decimal("50500"), + "volume": Decimal("100"), + } + + mock_result = MagicMock() + mock_result.fetchall.return_value = [mock_row] + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + result = await db.get_candles("BTCUSDT", "1m", 500) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["symbol"] == "BTCUSDT" + mock_session.execute.assert_awaited_once() |
