diff options
Diffstat (limited to 'shared/tests/test_db.py')
| -rw-r--r-- | shared/tests/test_db.py | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index 04efe9e..f4cabfd 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -335,3 +335,82 @@ class TestTransactionContextManager: mock_session.rollback.assert_awaited_once() mock_session.commit.assert_not_awaited() + + +class TestInsertPortfolioSnapshot: + @pytest.mark.asyncio + async def test_insert_portfolio_snapshot_uses_add_and_commit(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + await db.insert_portfolio_snapshot( + total_value=Decimal("10000"), + realized_pnl=Decimal("0"), + unrealized_pnl=Decimal("500"), + ) + + mock_session.add.assert_called_once() + mock_session.commit.assert_awaited_once() + + @pytest.mark.asyncio + async def test_insert_portfolio_snapshot_rollback_on_error(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock(side_effect=RuntimeError("db error")) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + with pytest.raises(RuntimeError, match="db error"): + await db.insert_portfolio_snapshot( + total_value=Decimal("10000"), + realized_pnl=Decimal("0"), + unrealized_pnl=Decimal("500"), + ) + + mock_session.rollback.assert_awaited_once() + + +class TestGetPortfolioSnapshots: + @pytest.mark.asyncio + async def test_get_portfolio_snapshots_returns_list_of_dicts(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_row = MagicMock() + mock_row.total_value = Decimal("10000") + mock_row.realized_pnl = Decimal("0") + mock_row.unrealized_pnl = Decimal("500") + mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=timezone.utc) + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_row] + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + result = await db.get_portfolio_snapshots(days=30) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["total_value"] == Decimal("10000") + assert result[0]["unrealized_pnl"] == Decimal("500") + mock_session.execute.assert_awaited_once() |
