diff options
| author | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 17:11:21 +0900 |
|---|---|---|
| committer | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 17:11:21 +0900 |
| commit | 2d1530f210f4b4f679a5d3b3597c4815904398a7 (patch) | |
| tree | ff2cda5482c4831a4111b0b1185ebdddec142ec4 /shared/tests/test_db.py | |
| parent | 76f934f95d3b5cbb96765e7158976e4a4c879fa9 (diff) | |
fix(shared): add transaction rollback on DB errors
Diffstat (limited to 'shared/tests/test_db.py')
| -rw-r--r-- | shared/tests/test_db.py | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index b9a9d56..59bf009 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -254,3 +254,84 @@ class TestGetCandles: assert len(result) == 1 assert result[0]["symbol"] == "BTCUSDT" mock_session.execute.assert_awaited_once() + + +class TestRollbackOnError: + @pytest.mark.asyncio + async def test_insert_candle_rollback_on_error(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.merge = MagicMock(return_value=None) + mock_session.commit = AsyncMock(side_effect=RuntimeError("db error")) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + candle = make_candle() + with pytest.raises(RuntimeError, match="db error"): + await db.insert_candle(candle) + + mock_session.rollback.assert_awaited_once() + + @pytest.mark.asyncio + async def test_insert_signal_rollback_on_error(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock(side_effect=RuntimeError("db error")) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + signal = make_signal() + with pytest.raises(RuntimeError, match="db error"): + await db.insert_signal(signal) + + mock_session.rollback.assert_awaited_once() + + +class TestTransactionContextManager: + @pytest.mark.asyncio + async def test_transaction_context_manager_commits(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + async with db.transaction() as session: + session.add(MagicMock()) + + mock_session.commit.assert_awaited_once() + mock_session.rollback.assert_not_awaited() + + @pytest.mark.asyncio + async def test_transaction_context_manager_rollback(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + with pytest.raises(ValueError, match="test error"): + async with db.transaction() as session: + raise ValueError("test error") + + mock_session.rollback.assert_awaited_once() + mock_session.commit.assert_not_awaited() |
