diff options
Diffstat (limited to 'shared/tests/test_db.py')
| -rw-r--r-- | shared/tests/test_db.py | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py new file mode 100644 index 0000000..c31e487 --- /dev/null +++ b/shared/tests/test_db.py @@ -0,0 +1,70 @@ +"""Tests for the database layer.""" +import pytest +from decimal import Decimal +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch, call + + +def make_candle(): + from shared.models import Candle + return Candle( + symbol="BTCUSDT", + timeframe="1m", + open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open=Decimal("50000"), + high=Decimal("51000"), + low=Decimal("49500"), + close=Decimal("50500"), + volume=Decimal("100"), + ) + + +@pytest.mark.asyncio +async def test_db_init_sql_creates_tables(): + """Verify that init_tables SQL references all required table names.""" + with patch("asyncpg.create_pool", new_callable=AsyncMock) as mock_pool: + mock_conn = AsyncMock() + mock_pool.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.return_value.__aexit__ = AsyncMock(return_value=False) + + # Capture the SQL that gets executed + executed_sqls = [] + + async def capture_execute(sql, *args, **kwargs): + executed_sqls.append(sql) + + mock_conn.execute = capture_execute + + from shared.db import Database + db = Database("postgresql://trading:trading@localhost:5432/trading") + db._pool = mock_pool.return_value + await db.init_tables() + + combined_sql = " ".join(executed_sqls) + for table in ["candles", "signals", "orders", "trades", "positions", "portfolio_snapshots"]: + assert table in combined_sql, f"Table '{table}' not found in SQL" + + +@pytest.mark.asyncio +async def test_db_insert_candle(): + """Verify that insert_candle executes INSERT INTO candles.""" + with patch("asyncpg.create_pool", new_callable=AsyncMock) as mock_pool: + mock_conn = AsyncMock() + mock_pool.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.return_value.__aexit__ = AsyncMock(return_value=False) + + executed = [] + + async def capture_execute(sql, *args, **kwargs): + executed.append((sql, args)) + + mock_conn.execute = capture_execute + + from shared.db import Database + db = Database("postgresql://trading:trading@localhost:5432/trading") + db._pool = mock_pool.return_value + candle = make_candle() + await db.insert_candle(candle) + + assert any("INSERT INTO candles" in sql for sql, _ in executed), \ + "Expected INSERT INTO candles" |
