summaryrefslogtreecommitdiff
path: root/shared/tests/test_db.py
diff options
context:
space:
mode:
Diffstat (limited to 'shared/tests/test_db.py')
-rw-r--r--shared/tests/test_db.py236
1 files changed, 203 insertions, 33 deletions
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
index c31e487..45d5dcd 100644
--- a/shared/tests/test_db.py
+++ b/shared/tests/test_db.py
@@ -1,4 +1,4 @@
-"""Tests for the database layer."""
+"""Tests for the SQLAlchemy async database layer."""
import pytest
from decimal import Decimal
from datetime import datetime, timezone
@@ -19,52 +19,222 @@ def make_candle():
)
-@pytest.mark.asyncio
-async def test_db_init_sql_creates_tables():
- """Verify that init_tables SQL references all required table names."""
- with patch("asyncpg.create_pool", new_callable=AsyncMock) as mock_pool:
- mock_conn = AsyncMock()
- mock_pool.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
- mock_pool.return_value.__aexit__ = AsyncMock(return_value=False)
+def make_signal():
+ from shared.models import Signal, OrderSide
+ return Signal(
+ id="sig-1",
+ strategy="ma_cross",
+ symbol="BTCUSDT",
+ side=OrderSide.BUY,
+ price=Decimal("50000"),
+ quantity=Decimal("0.1"),
+ reason="Golden cross",
+ created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ )
+
+
+def make_order():
+ from shared.models import Order, OrderSide, OrderType, OrderStatus
+ return Order(
+ id="ord-1",
+ signal_id="sig-1",
+ symbol="BTCUSDT",
+ side=OrderSide.BUY,
+ type=OrderType.LIMIT,
+ price=Decimal("50000"),
+ quantity=Decimal("0.1"),
+ status=OrderStatus.PENDING,
+ created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ )
+
- # Capture the SQL that gets executed
- executed_sqls = []
+class TestDatabaseConstructor:
+ def test_stores_url(self):
+ from shared.db import Database
+ db = Database("postgresql://user:pass@localhost/db")
+ assert db._database_url == "postgresql+asyncpg://user:pass@localhost/db"
- async def capture_execute(sql, *args, **kwargs):
- executed_sqls.append(sql)
+ def test_converts_url_prefix(self):
+ from shared.db import Database
+ db = Database("postgresql://host/db")
+ assert db._database_url.startswith("postgresql+asyncpg://")
- mock_conn.execute = capture_execute
+ def test_keeps_asyncpg_prefix(self):
+ from shared.db import Database
+ db = Database("postgresql+asyncpg://host/db")
+ assert db._database_url == "postgresql+asyncpg://host/db"
+ def test_get_session_exists(self):
from shared.db import Database
- db = Database("postgresql://trading:trading@localhost:5432/trading")
- db._pool = mock_pool.return_value
- await db.init_tables()
+ db = Database("postgresql+asyncpg://host/db")
+ assert hasattr(db, "get_session")
- combined_sql = " ".join(executed_sqls)
- for table in ["candles", "signals", "orders", "trades", "positions", "portfolio_snapshots"]:
- assert table in combined_sql, f"Table '{table}' not found in SQL"
+class TestDatabaseConnect:
+ @pytest.mark.asyncio
+ async def test_connect_creates_engine_and_tables(self):
+ from shared.db import Database
+ db = Database("postgresql+asyncpg://host/db")
-@pytest.mark.asyncio
-async def test_db_insert_candle():
- """Verify that insert_candle executes INSERT INTO candles."""
- with patch("asyncpg.create_pool", new_callable=AsyncMock) as mock_pool:
mock_conn = AsyncMock()
- mock_pool.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
- mock_pool.return_value.__aexit__ = AsyncMock(return_value=False)
+ mock_cm = AsyncMock()
+ mock_cm.__aenter__.return_value = mock_conn
+
+ mock_engine = MagicMock()
+ mock_engine.begin.return_value = mock_cm
+ mock_engine.dispose = AsyncMock()
+
+ with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create:
+ with patch("shared.db.async_sessionmaker"):
+ with patch("shared.db.Base") as mock_base:
+ mock_base.metadata.create_all = MagicMock()
+ await db.connect()
+ mock_create.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_init_tables_is_alias_for_connect(self):
+ from shared.db import Database
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_conn = AsyncMock()
+ mock_cm = AsyncMock()
+ mock_cm.__aenter__.return_value = mock_conn
- executed = []
+ mock_engine = MagicMock()
+ mock_engine.begin.return_value = mock_cm
+ mock_engine.dispose = AsyncMock()
- async def capture_execute(sql, *args, **kwargs):
- executed.append((sql, args))
+ with patch("shared.db.create_async_engine", return_value=mock_engine):
+ with patch("shared.db.async_sessionmaker"):
+ with patch("shared.db.Base") as mock_base:
+ mock_base.metadata.create_all = MagicMock()
+ await db.init_tables()
+ # Should succeed without error (same as connect)
+
+
+class TestDatabaseClose:
+ @pytest.mark.asyncio
+ async def test_close_disposes_engine(self):
+ from shared.db import Database
+ db = Database("postgresql+asyncpg://host/db")
+ mock_engine = AsyncMock()
+ db._engine = mock_engine
+ await db.close()
+ mock_engine.dispose.assert_awaited_once()
- mock_conn.execute = capture_execute
+class TestInsertCandle:
+ @pytest.mark.asyncio
+ async def test_insert_candle_uses_merge_and_commit(self):
from shared.db import Database
- db = Database("postgresql://trading:trading@localhost:5432/trading")
- db._pool = mock_pool.return_value
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.merge = MagicMock(return_value=None)
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
candle = make_candle()
await db.insert_candle(candle)
- assert any("INSERT INTO candles" in sql for sql, _ in executed), \
- "Expected INSERT INTO candles"
+ mock_session.merge.assert_called_once()
+ mock_session.commit.assert_awaited_once()
+
+
+class TestInsertSignal:
+ @pytest.mark.asyncio
+ async def test_insert_signal_uses_add_and_commit(self):
+ from shared.db import Database
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.add = MagicMock()
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ signal = make_signal()
+ await db.insert_signal(signal)
+
+ mock_session.add.assert_called_once()
+ mock_session.commit.assert_awaited_once()
+
+
+class TestInsertOrder:
+ @pytest.mark.asyncio
+ async def test_insert_order_uses_add_and_commit(self):
+ from shared.db import Database
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.add = MagicMock()
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ order = make_order()
+ await db.insert_order(order)
+
+ mock_session.add.assert_called_once()
+ mock_session.commit.assert_awaited_once()
+
+
+class TestUpdateOrderStatus:
+ @pytest.mark.asyncio
+ async def test_update_order_status_uses_execute_and_commit(self):
+ from shared.db import Database
+ from shared.models import OrderStatus
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ filled = datetime(2024, 1, 2, tzinfo=timezone.utc)
+ await db.update_order_status("ord-1", OrderStatus.FILLED, filled)
+
+ mock_session.execute.assert_awaited_once()
+ mock_session.commit.assert_awaited_once()
+
+
+class TestGetCandles:
+ @pytest.mark.asyncio
+ async def test_get_candles_returns_list_of_dicts(self):
+ from shared.db import Database
+ db = Database("postgresql+asyncpg://host/db")
+
+ # Create a mock row that behaves like a SA result row
+ mock_row = MagicMock()
+ mock_row._mapping = {
+ "symbol": "BTCUSDT",
+ "timeframe": "1m",
+ "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc),
+ "open": Decimal("50000"),
+ "high": Decimal("51000"),
+ "low": Decimal("49500"),
+ "close": Decimal("50500"),
+ "volume": Decimal("100"),
+ }
+
+ mock_result = MagicMock()
+ mock_result.fetchall.return_value = [mock_row]
+
+ mock_session = AsyncMock()
+ mock_session.execute = AsyncMock(return_value=mock_result)
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ result = await db.get_candles("BTCUSDT", "1m", 500)
+
+ assert isinstance(result, list)
+ assert len(result) == 1
+ assert result[0]["symbol"] == "BTCUSDT"
+ mock_session.execute.assert_awaited_once()