summaryrefslogtreecommitdiff
path: root/shared/tests/test_db.py
diff options
context:
space:
mode:
Diffstat (limited to 'shared/tests/test_db.py')
-rw-r--r--shared/tests/test_db.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
index 45d5dcd..b9a9d56 100644
--- a/shared/tests/test_db.py
+++ b/shared/tests/test_db.py
@@ -1,12 +1,14 @@
"""Tests for the SQLAlchemy async database layer."""
+
import pytest
from decimal import Decimal
from datetime import datetime, timezone
-from unittest.mock import AsyncMock, MagicMock, patch, call
+from unittest.mock import AsyncMock, MagicMock, patch
def make_candle():
from shared.models import Candle
+
return Candle(
symbol="BTCUSDT",
timeframe="1m",
@@ -21,6 +23,7 @@ def make_candle():
def make_signal():
from shared.models import Signal, OrderSide
+
return Signal(
id="sig-1",
strategy="ma_cross",
@@ -35,6 +38,7 @@ def make_signal():
def make_order():
from shared.models import Order, OrderSide, OrderType, OrderStatus
+
return Order(
id="ord-1",
signal_id="sig-1",
@@ -51,21 +55,25 @@ def make_order():
class TestDatabaseConstructor:
def test_stores_url(self):
from shared.db import Database
+
db = Database("postgresql://user:pass@localhost/db")
assert db._database_url == "postgresql+asyncpg://user:pass@localhost/db"
def test_converts_url_prefix(self):
from shared.db import Database
+
db = Database("postgresql://host/db")
assert db._database_url.startswith("postgresql+asyncpg://")
def test_keeps_asyncpg_prefix(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
assert db._database_url == "postgresql+asyncpg://host/db"
def test_get_session_exists(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
assert hasattr(db, "get_session")
@@ -74,6 +82,7 @@ class TestDatabaseConnect:
@pytest.mark.asyncio
async def test_connect_creates_engine_and_tables(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
mock_conn = AsyncMock()
@@ -94,6 +103,7 @@ class TestDatabaseConnect:
@pytest.mark.asyncio
async def test_init_tables_is_alias_for_connect(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
mock_conn = AsyncMock()
@@ -116,6 +126,7 @@ class TestDatabaseClose:
@pytest.mark.asyncio
async def test_close_disposes_engine(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
mock_engine = AsyncMock()
db._engine = mock_engine
@@ -127,6 +138,7 @@ class TestInsertCandle:
@pytest.mark.asyncio
async def test_insert_candle_uses_merge_and_commit(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
mock_session = AsyncMock()
@@ -147,6 +159,7 @@ class TestInsertSignal:
@pytest.mark.asyncio
async def test_insert_signal_uses_add_and_commit(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
mock_session = AsyncMock()
@@ -167,6 +180,7 @@ class TestInsertOrder:
@pytest.mark.asyncio
async def test_insert_order_uses_add_and_commit(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
mock_session = AsyncMock()
@@ -188,6 +202,7 @@ class TestUpdateOrderStatus:
async def test_update_order_status_uses_execute_and_commit(self):
from shared.db import Database
from shared.models import OrderStatus
+
db = Database("postgresql+asyncpg://host/db")
mock_session = AsyncMock()
@@ -207,6 +222,7 @@ class TestGetCandles:
@pytest.mark.asyncio
async def test_get_candles_returns_list_of_dicts(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
# Create a mock row that behaves like a SA result row