summaryrefslogtreecommitdiff
path: root/shared/tests/test_db.py
diff options
context:
space:
mode:
Diffstat (limited to 'shared/tests/test_db.py')
-rw-r--r--shared/tests/test_db.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
new file mode 100644
index 0000000..c31e487
--- /dev/null
+++ b/shared/tests/test_db.py
@@ -0,0 +1,70 @@
+"""Tests for the database layer."""
+import pytest
+from decimal import Decimal
+from datetime import datetime, timezone
+from unittest.mock import AsyncMock, MagicMock, patch, call
+
+
+def make_candle():
+ from shared.models import Candle
+ return Candle(
+ symbol="BTCUSDT",
+ timeframe="1m",
+ open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open=Decimal("50000"),
+ high=Decimal("51000"),
+ low=Decimal("49500"),
+ close=Decimal("50500"),
+ volume=Decimal("100"),
+ )
+
+
+@pytest.mark.asyncio
+async def test_db_init_sql_creates_tables():
+ """Verify that init_tables SQL references all required table names."""
+ with patch("asyncpg.create_pool", new_callable=AsyncMock) as mock_pool:
+ mock_conn = AsyncMock()
+ mock_pool.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
+ mock_pool.return_value.__aexit__ = AsyncMock(return_value=False)
+
+ # Capture the SQL that gets executed
+ executed_sqls = []
+
+ async def capture_execute(sql, *args, **kwargs):
+ executed_sqls.append(sql)
+
+ mock_conn.execute = capture_execute
+
+ from shared.db import Database
+ db = Database("postgresql://trading:trading@localhost:5432/trading")
+ db._pool = mock_pool.return_value
+ await db.init_tables()
+
+ combined_sql = " ".join(executed_sqls)
+ for table in ["candles", "signals", "orders", "trades", "positions", "portfolio_snapshots"]:
+ assert table in combined_sql, f"Table '{table}' not found in SQL"
+
+
+@pytest.mark.asyncio
+async def test_db_insert_candle():
+ """Verify that insert_candle executes INSERT INTO candles."""
+ with patch("asyncpg.create_pool", new_callable=AsyncMock) as mock_pool:
+ mock_conn = AsyncMock()
+ mock_pool.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
+ mock_pool.return_value.__aexit__ = AsyncMock(return_value=False)
+
+ executed = []
+
+ async def capture_execute(sql, *args, **kwargs):
+ executed.append((sql, args))
+
+ mock_conn.execute = capture_execute
+
+ from shared.db import Database
+ db = Database("postgresql://trading:trading@localhost:5432/trading")
+ db._pool = mock_pool.return_value
+ candle = make_candle()
+ await db.insert_candle(candle)
+
+ assert any("INSERT INTO candles" in sql for sql, _ in executed), \
+ "Expected INSERT INTO candles"