"""Tests for the SQLAlchemy async database layer.

``shared.db`` and ``shared.models`` are imported inside the tests and helpers
rather than at module top. Sessions and engines are replaced with
``unittest.mock`` doubles, so no real database is used.
"""

from datetime import datetime, timezone
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# Already-async URL shared by every test that does not exercise URL handling.
DB_URL = "postgresql+asyncpg://host/db"


def make_candle():
    """Return a sample BTCUSDT 1m ``Candle``."""
    from shared.models import Candle

    return Candle(
        symbol="BTCUSDT",
        timeframe="1m",
        open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
        open=Decimal("50000"),
        high=Decimal("51000"),
        low=Decimal("49500"),
        close=Decimal("50500"),
        volume=Decimal("100"),
    )


def make_signal():
    """Return a sample BUY ``Signal`` with id ``sig-1``."""
    from shared.models import OrderSide, Signal

    return Signal(
        id="sig-1",
        strategy="ma_cross",
        symbol="BTCUSDT",
        side=OrderSide.BUY,
        price=Decimal("50000"),
        quantity=Decimal("0.1"),
        reason="Golden cross",
        created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
    )


def make_order():
    """Return a sample pending LIMIT BUY ``Order`` linked to ``sig-1``."""
    from shared.models import Order, OrderSide, OrderStatus, OrderType

    return Order(
        id="ord-1",
        signal_id="sig-1",
        symbol="BTCUSDT",
        side=OrderSide.BUY,
        type=OrderType.LIMIT,
        price=Decimal("50000"),
        quantity=Decimal("0.1"),
        status=OrderStatus.PENDING,
        created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
    )


def make_session(**attrs):
    """Return an ``AsyncMock`` standing in for an async SQLAlchemy session.

    ``async with`` on the mock yields the mock itself, so assertions on the
    returned object see every call made by the code under test. Keyword
    arguments are set as attributes on the mock, e.g. ``add=MagicMock()`` for
    a method the code calls without awaiting, or
    ``commit=AsyncMock(side_effect=...)`` to simulate a failure.
    """
    session = AsyncMock()
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=False)
    for name, value in attrs.items():
        setattr(session, name, value)
    return session


def make_db(session=None):
    """Return a ``Database`` on ``DB_URL``, optionally wired to *session*.

    When *session* is given, ``db._session_factory`` is replaced by a mock
    that returns that session on every call.
    """
    from shared.db import Database

    db = Database(DB_URL)
    if session is not None:
        db._session_factory = MagicMock(return_value=session)
    return db


def make_engine():
    """Return a mock engine whose ``begin()`` context yields an ``AsyncMock``.

    ``dispose`` is an ``AsyncMock`` so it can be awaited.
    """
    conn = AsyncMock()
    begin_cm = AsyncMock()
    begin_cm.__aenter__.return_value = conn
    engine = MagicMock()
    engine.begin.return_value = begin_cm
    engine.dispose = AsyncMock()
    return engine


class TestDatabaseConstructor:
    """URL normalisation and basic surface of ``Database.__init__``."""

    def test_stores_url(self):
        from shared.db import Database

        db = Database("postgresql://user:pass@localhost/db")
        assert db._database_url == "postgresql+asyncpg://user:pass@localhost/db"

    def test_converts_url_prefix(self):
        from shared.db import Database

        db = Database("postgresql://host/db")
        assert db._database_url.startswith("postgresql+asyncpg://")

    def test_keeps_asyncpg_prefix(self):
        from shared.db import Database

        db = Database("postgresql+asyncpg://host/db")
        assert db._database_url == "postgresql+asyncpg://host/db"

    def test_get_session_exists(self):
        from shared.db import Database

        db = Database("postgresql+asyncpg://host/db")
        assert hasattr(db, "get_session")


class TestDatabaseConnect:
    """``connect`` / ``init_tables`` with engine creation patched out."""

    @pytest.mark.asyncio
    async def test_connect_creates_engine_and_tables(self):
        db = make_db()
        engine = make_engine()

        with patch(
            "shared.db.create_async_engine", return_value=engine
        ) as mock_create:
            with patch("shared.db.async_sessionmaker"):
                with patch("shared.db.Base") as mock_base:
                    mock_base.metadata.create_all = MagicMock()
                    await db.connect()

        # NOTE(review): only engine creation is asserted; the "tables" half of
        # this test's name is not verified here.
        mock_create.assert_called_once()

    @pytest.mark.asyncio
    async def test_init_tables_is_alias_for_connect(self):
        db = make_db()
        engine = make_engine()

        with patch(
            "shared.db.create_async_engine", return_value=engine
        ) as mock_create:
            with patch("shared.db.async_sessionmaker"):
                with patch("shared.db.Base") as mock_base:
                    mock_base.metadata.create_all = MagicMock()
                    await db.init_tables()

        # As an alias for connect(), init_tables() must do the same work.
        mock_create.assert_called_once()


class TestDatabaseClose:
    """``close`` releases the engine."""

    @pytest.mark.asyncio
    async def test_close_disposes_engine(self):
        db = make_db()
        engine = AsyncMock()
        db._engine = engine

        await db.close()

        engine.dispose.assert_awaited_once()


class TestInsertCandle:
    """``insert_candle`` writes via ``merge`` and commits."""

    @pytest.mark.asyncio
    async def test_insert_candle_uses_merge_and_commit(self):
        # NOTE(review): SQLAlchemy's AsyncSession.merge() is a coroutine
        # function, but ``merge`` is mocked as synchronous here (awaiting its
        # ``None`` return value would raise TypeError). If this passes,
        # insert_candle presumably calls merge() without awaiting it. Verify
        # against shared.db, because on a real AsyncSession that never merges.
        session = make_session(merge=MagicMock(return_value=None))
        db = make_db(session)

        await db.insert_candle(make_candle())

        session.merge.assert_called_once()
        session.commit.assert_awaited_once()


class TestInsertSignal:
    """``insert_signal`` writes via ``add`` and commits."""

    @pytest.mark.asyncio
    async def test_insert_signal_uses_add_and_commit(self):
        session = make_session(add=MagicMock())
        db = make_db(session)

        await db.insert_signal(make_signal())

        session.add.assert_called_once()
        session.commit.assert_awaited_once()


class TestInsertOrder:
    """``insert_order`` writes via ``add`` and commits."""

    @pytest.mark.asyncio
    async def test_insert_order_uses_add_and_commit(self):
        session = make_session(add=MagicMock())
        db = make_db(session)

        await db.insert_order(make_order())

        session.add.assert_called_once()
        session.commit.assert_awaited_once()


class TestUpdateOrderStatus:
    """``update_order_status`` issues one statement and commits."""

    @pytest.mark.asyncio
    async def test_update_order_status_uses_execute_and_commit(self):
        from shared.models import OrderStatus

        session = make_session()
        db = make_db(session)
        filled = datetime(2024, 1, 2, tzinfo=timezone.utc)

        await db.update_order_status("ord-1", OrderStatus.FILLED, filled)

        session.execute.assert_awaited_once()
        session.commit.assert_awaited_once()


class TestGetCandles:
    """``get_candles`` converts result rows into dicts."""

    @pytest.mark.asyncio
    async def test_get_candles_returns_list_of_dicts(self):
        # A mock row that behaves like a SQLAlchemy result row.
        row = MagicMock()
        row._mapping = {
            "symbol": "BTCUSDT",
            "timeframe": "1m",
            "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc),
            "open": Decimal("50000"),
            "high": Decimal("51000"),
            "low": Decimal("49500"),
            "close": Decimal("50500"),
            "volume": Decimal("100"),
        }
        query_result = MagicMock()
        query_result.fetchall.return_value = [row]
        session = make_session(execute=AsyncMock(return_value=query_result))
        db = make_db(session)

        result = await db.get_candles("BTCUSDT", "1m", 500)

        assert isinstance(result, list)
        assert len(result) == 1
        assert result[0]["symbol"] == "BTCUSDT"
        session.execute.assert_awaited_once()


class TestRollbackOnError:
    """A failing commit is rolled back and the error is re-raised."""

    @pytest.mark.asyncio
    async def test_insert_candle_rollback_on_error(self):
        session = make_session(
            merge=MagicMock(return_value=None),
            commit=AsyncMock(side_effect=RuntimeError("db error")),
        )
        db = make_db(session)

        with pytest.raises(RuntimeError, match="db error"):
            await db.insert_candle(make_candle())

        session.rollback.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_insert_signal_rollback_on_error(self):
        session = make_session(
            add=MagicMock(),
            commit=AsyncMock(side_effect=RuntimeError("db error")),
        )
        db = make_db(session)

        with pytest.raises(RuntimeError, match="db error"):
            await db.insert_signal(make_signal())

        session.rollback.assert_awaited_once()


class TestTransactionContextManager:
    """``transaction()`` commits on success and rolls back on error."""

    @pytest.mark.asyncio
    async def test_transaction_context_manager_commits(self):
        # ``add`` is synchronous on a real AsyncSession. A bare AsyncMock child
        # would return a coroutine that is never awaited, so mock it as sync.
        session = make_session(add=MagicMock())
        db = make_db(session)

        async with db.transaction() as tx_session:
            tx_session.add(MagicMock())

        session.commit.assert_awaited_once()
        session.rollback.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_transaction_context_manager_rollback(self):
        session = make_session()
        db = make_db(session)

        with pytest.raises(ValueError, match="test error"):
            async with db.transaction() as _session:
                raise ValueError("test error")

        session.rollback.assert_awaited_once()
        session.commit.assert_not_awaited()


class TestInsertPortfolioSnapshot:
    """``insert_portfolio_snapshot`` persistence and error handling."""

    @pytest.mark.asyncio
    async def test_insert_portfolio_snapshot_uses_add_and_commit(self):
        session = make_session(add=MagicMock())
        db = make_db(session)

        await db.insert_portfolio_snapshot(
            total_value=Decimal("10000"),
            realized_pnl=Decimal("0"),
            unrealized_pnl=Decimal("500"),
        )

        session.add.assert_called_once()
        session.commit.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_insert_portfolio_snapshot_rollback_on_error(self):
        session = make_session(
            add=MagicMock(),
            commit=AsyncMock(side_effect=RuntimeError("db error")),
        )
        db = make_db(session)

        with pytest.raises(RuntimeError, match="db error"):
            await db.insert_portfolio_snapshot(
                total_value=Decimal("10000"),
                realized_pnl=Decimal("0"),
                unrealized_pnl=Decimal("500"),
            )

        session.rollback.assert_awaited_once()


class TestGetPortfolioSnapshots:
    """``get_portfolio_snapshots`` converts ORM rows into dicts."""

    @pytest.mark.asyncio
    async def test_get_portfolio_snapshots_returns_list_of_dicts(self):
        row = MagicMock()
        row.total_value = Decimal("10000")
        row.realized_pnl = Decimal("0")
        row.unrealized_pnl = Decimal("500")
        row.snapshot_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
        query_result = MagicMock()
        query_result.scalars.return_value.all.return_value = [row]
        session = make_session(execute=AsyncMock(return_value=query_result))
        db = make_db(session)

        result = await db.get_portfolio_snapshots(days=30)

        assert isinstance(result, list)
        assert len(result) == 1
        assert result[0]["total_value"] == Decimal("10000")
        assert result[0]["unrealized_pnl"] == Decimal("500")
        session.execute.assert_awaited_once()