summaryrefslogtreecommitdiff
path: root/shared/tests/test_db.py
diff options
context:
space:
mode:
Diffstat (limited to 'shared/tests/test_db.py')
-rw-r--r--shared/tests/test_db.py69
1 files changed, 59 insertions, 10 deletions
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
index 239ee64..b44a713 100644
--- a/shared/tests/test_db.py
+++ b/shared/tests/test_db.py
@@ -1,10 +1,11 @@
"""Tests for the SQLAlchemy async database layer."""
-import pytest
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
+import pytest
+
def make_candle():
from shared.models import Candle
@@ -12,7 +13,7 @@ def make_candle():
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("51000"),
low=Decimal("49500"),
@@ -22,7 +23,7 @@ def make_candle():
def make_signal():
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
return Signal(
id="sig-1",
@@ -32,12 +33,12 @@ def make_signal():
price=Decimal("50000"),
quantity=Decimal("0.1"),
reason="Golden cross",
- created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
def make_order():
- from shared.models import Order, OrderSide, OrderType, OrderStatus
+ from shared.models import Order, OrderSide, OrderStatus, OrderType
return Order(
id="ord-1",
@@ -48,7 +49,7 @@ def make_order():
price=Decimal("50000"),
quantity=Decimal("0.1"),
status=OrderStatus.PENDING,
- created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
@@ -101,6 +102,54 @@ class TestDatabaseConnect:
mock_create.assert_called_once()
@pytest.mark.asyncio
+ async def test_connect_passes_pool_params_for_postgres(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_conn = AsyncMock()
+ mock_cm = AsyncMock()
+ mock_cm.__aenter__.return_value = mock_conn
+
+ mock_engine = MagicMock()
+ mock_engine.begin.return_value = mock_cm
+ mock_engine.dispose = AsyncMock()
+
+ with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create:
+ with patch("shared.db.async_sessionmaker"):
+ with patch("shared.db.Base") as mock_base:
+ mock_base.metadata.create_all = MagicMock()
+ await db.connect(pool_size=5, max_overflow=3, pool_recycle=1800)
+ mock_create.assert_called_once_with(
+ "postgresql+asyncpg://host/db",
+ pool_pre_ping=True,
+ pool_size=5,
+ max_overflow=3,
+ pool_recycle=1800,
+ )
+
+ @pytest.mark.asyncio
+ async def test_connect_skips_pool_params_for_sqlite(self):
+ from shared.db import Database
+
+ db = Database("sqlite+aiosqlite:///test.db")
+
+ mock_conn = AsyncMock()
+ mock_cm = AsyncMock()
+ mock_cm.__aenter__.return_value = mock_conn
+
+ mock_engine = MagicMock()
+ mock_engine.begin.return_value = mock_cm
+ mock_engine.dispose = AsyncMock()
+
+ with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create:
+ with patch("shared.db.async_sessionmaker"):
+ with patch("shared.db.Base") as mock_base:
+ mock_base.metadata.create_all = MagicMock()
+ await db.connect()
+ mock_create.assert_called_once_with("sqlite+aiosqlite:///test.db")
+
+ @pytest.mark.asyncio
async def test_init_tables_is_alias_for_connect(self):
from shared.db import Database
@@ -211,7 +260,7 @@ class TestUpdateOrderStatus:
db._session_factory = MagicMock(return_value=mock_session)
- filled = datetime(2024, 1, 2, tzinfo=timezone.utc)
+ filled = datetime(2024, 1, 2, tzinfo=UTC)
await db.update_order_status("ord-1", OrderStatus.FILLED, filled)
mock_session.execute.assert_awaited_once()
@@ -230,7 +279,7 @@ class TestGetCandles:
mock_row._mapping = {
"symbol": "AAPL",
"timeframe": "1m",
- "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc),
+ "open_time": datetime(2024, 1, 1, tzinfo=UTC),
"open": Decimal("50000"),
"high": Decimal("51000"),
"low": Decimal("49500"),
@@ -396,7 +445,7 @@ class TestGetPortfolioSnapshots:
mock_row.total_value = Decimal("10000")
mock_row.realized_pnl = Decimal("0")
mock_row.unrealized_pnl = Decimal("500")
- mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
+ mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=UTC)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [mock_row]