summaryrefslogtreecommitdiff
path: root/shared/tests/test_db.py
diff options
context:
space:
mode:
Diffstat (limited to 'shared/tests/test_db.py')
-rw-r--r--shared/tests/test_db.py81
1 files changed, 81 insertions, 0 deletions
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
index b9a9d56..59bf009 100644
--- a/shared/tests/test_db.py
+++ b/shared/tests/test_db.py
@@ -254,3 +254,84 @@ class TestGetCandles:
assert len(result) == 1
assert result[0]["symbol"] == "BTCUSDT"
mock_session.execute.assert_awaited_once()
+
+
+class TestRollbackOnError:
+ @pytest.mark.asyncio
+ async def test_insert_candle_rollback_on_error(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.merge = MagicMock(return_value=None)
+ mock_session.commit = AsyncMock(side_effect=RuntimeError("db error"))
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ candle = make_candle()
+ with pytest.raises(RuntimeError, match="db error"):
+ await db.insert_candle(candle)
+
+ mock_session.rollback.assert_awaited_once()
+
+ @pytest.mark.asyncio
+ async def test_insert_signal_rollback_on_error(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.add = MagicMock()
+ mock_session.commit = AsyncMock(side_effect=RuntimeError("db error"))
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ signal = make_signal()
+ with pytest.raises(RuntimeError, match="db error"):
+ await db.insert_signal(signal)
+
+ mock_session.rollback.assert_awaited_once()
+
+
+class TestTransactionContextManager:
+ @pytest.mark.asyncio
+ async def test_transaction_context_manager_commits(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ async with db.transaction() as session:
+ session.add(MagicMock())
+
+ mock_session.commit.assert_awaited_once()
+ mock_session.rollback.assert_not_awaited()
+
+ @pytest.mark.asyncio
+ async def test_transaction_context_manager_rollback(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ with pytest.raises(ValueError, match="test error"):
+ async with db.transaction() as session:
+ raise ValueError("test error")
+
+ mock_session.rollback.assert_awaited_once()
+ mock_session.commit.assert_not_awaited()