From 2d1530f210f4b4f679a5d3b3597c4815904398a7 Mon Sep 17 00:00:00 2001 From: TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> Date: Wed, 1 Apr 2026 17:11:21 +0900 Subject: fix(shared): add transaction rollback on DB errors --- shared/src/shared/db.py | 52 +++++++++++++++++++++++++------ shared/tests/test_db.py | 81 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 10 deletions(-) diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py index f9b7f56..515ba2c 100644 --- a/shared/src/shared/db.py +++ b/shared/src/shared/db.py @@ -1,5 +1,6 @@ """Database layer using SQLAlchemy 2.0 async ORM for the trading platform.""" +from contextlib import asynccontextmanager from datetime import datetime from typing import Optional @@ -42,6 +43,17 @@ class Database: """Return a new async session from the factory.""" return self._session_factory() + @asynccontextmanager + async def transaction(self): + """Provide a transactional scope with automatic rollback on error.""" + async with self.get_session() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + async def insert_candle(self, candle: Candle) -> None: """Upsert a candle row using session.merge.""" row = CandleRow( @@ -55,8 +67,12 @@ class Database: volume=candle.volume, ) async with self._session_factory() as session: - session.merge(row) - await session.commit() + try: + session.merge(row) + await session.commit() + except Exception: + await session.rollback() + raise async def insert_signal(self, signal: Signal) -> None: """Insert a signal row.""" @@ -71,8 +87,12 @@ class Database: created_at=signal.created_at, ) async with self._session_factory() as session: - session.add(row) - await session.commit() + try: + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise async def insert_order(self, order: Order) -> None: """Insert an order row.""" @@ -89,8 +109,12 @@ class Database: filled_at=order.filled_at, ) async with self._session_factory() as session: - session.add(row) - await session.commit() + try: + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise async def update_order_status( self, @@ -105,8 +129,12 @@ class Database: .values(status=status.value, filled_at=filled_at) ) async with self._session_factory() as session: - await session.execute(stmt) - await session.commit() + try: + await session.execute(stmt) + await session.commit() + except Exception: + await session.rollback() + raise async def get_candles(self, symbol: str, timeframe: str, limit: int = 500) -> list[dict]: """Retrieve candles ordered by open_time descending.""" @@ -117,6 +145,10 @@ class Database: .limit(limit) ) async with self._session_factory() as session: - result = await session.execute(stmt) - rows = result.fetchall() + try: + result = await session.execute(stmt) + rows = result.fetchall() + except Exception: + await session.rollback() + raise return [dict(row._mapping) for row in rows] diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index b9a9d56..59bf009 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -254,3 +254,84 @@ class TestGetCandles: assert len(result) == 1 assert result[0]["symbol"] == "BTCUSDT" mock_session.execute.assert_awaited_once() + + +class TestRollbackOnError: + @pytest.mark.asyncio + async def test_insert_candle_rollback_on_error(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.merge = MagicMock(return_value=None) + mock_session.commit = AsyncMock(side_effect=RuntimeError("db error")) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + candle = make_candle() + with pytest.raises(RuntimeError, match="db error"): + await db.insert_candle(candle) + + mock_session.rollback.assert_awaited_once() + + @pytest.mark.asyncio + async def test_insert_signal_rollback_on_error(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock(side_effect=RuntimeError("db error")) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + signal = make_signal() + with pytest.raises(RuntimeError, match="db error"): + await db.insert_signal(signal) + + mock_session.rollback.assert_awaited_once() + + +class TestTransactionContextManager: + @pytest.mark.asyncio + async def test_transaction_context_manager_commits(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + async with db.transaction() as session: + session.add(MagicMock()) + + mock_session.commit.assert_awaited_once() + mock_session.rollback.assert_not_awaited() + + @pytest.mark.asyncio + async def test_transaction_context_manager_rollback(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_session = AsyncMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + db._session_factory = MagicMock(return_value=mock_session) + + with pytest.raises(ValueError, match="test error"): + async with db.transaction() as session: + raise ValueError("test error") + + mock_session.rollback.assert_awaited_once() + mock_session.commit.assert_not_awaited() -- cgit v1.2.3