summaryrefslogtreecommitdiff
path: root/shared
diff options
context:
space:
mode:
authorTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-01 17:11:21 +0900
committerTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-01 17:11:21 +0900
commit2d1530f210f4b4f679a5d3b3597c4815904398a7 (patch)
treeff2cda5482c4831a4111b0b1185ebdddec142ec4 /shared
parent76f934f95d3b5cbb96765e7158976e4a4c879fa9 (diff)
fix(shared): add transaction rollback on DB errors
Diffstat (limited to 'shared')
-rw-r--r--shared/src/shared/db.py52
-rw-r--r--shared/tests/test_db.py81
2 files changed, 123 insertions, 10 deletions
diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py
index f9b7f56..515ba2c 100644
--- a/shared/src/shared/db.py
+++ b/shared/src/shared/db.py
@@ -1,5 +1,6 @@
"""Database layer using SQLAlchemy 2.0 async ORM for the trading platform."""
+from contextlib import asynccontextmanager
from datetime import datetime
from typing import Optional
@@ -42,6 +43,17 @@ class Database:
"""Return a new async session from the factory."""
return self._session_factory()
+ @asynccontextmanager
+ async def transaction(self):
+ """Provide a transactional scope with automatic rollback on error."""
+ async with self.get_session() as session:
+ try:
+ yield session
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
+
async def insert_candle(self, candle: Candle) -> None:
"""Upsert a candle row using session.merge."""
row = CandleRow(
@@ -55,8 +67,12 @@ class Database:
volume=candle.volume,
)
async with self._session_factory() as session:
- session.merge(row)
- await session.commit()
+ try:
+ session.merge(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
async def insert_signal(self, signal: Signal) -> None:
"""Insert a signal row."""
@@ -71,8 +87,12 @@ class Database:
created_at=signal.created_at,
)
async with self._session_factory() as session:
- session.add(row)
- await session.commit()
+ try:
+ session.add(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
async def insert_order(self, order: Order) -> None:
"""Insert an order row."""
@@ -89,8 +109,12 @@ class Database:
filled_at=order.filled_at,
)
async with self._session_factory() as session:
- session.add(row)
- await session.commit()
+ try:
+ session.add(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
async def update_order_status(
self,
@@ -105,8 +129,12 @@ class Database:
.values(status=status.value, filled_at=filled_at)
)
async with self._session_factory() as session:
- await session.execute(stmt)
- await session.commit()
+ try:
+ await session.execute(stmt)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
async def get_candles(self, symbol: str, timeframe: str, limit: int = 500) -> list[dict]:
"""Retrieve candles ordered by open_time descending."""
@@ -117,6 +145,10 @@ class Database:
.limit(limit)
)
async with self._session_factory() as session:
- result = await session.execute(stmt)
- rows = result.fetchall()
+ try:
+ result = await session.execute(stmt)
+ rows = result.fetchall()
+ except Exception:
+ await session.rollback()
+ raise
return [dict(row._mapping) for row in rows]
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
index b9a9d56..59bf009 100644
--- a/shared/tests/test_db.py
+++ b/shared/tests/test_db.py
@@ -254,3 +254,84 @@ class TestGetCandles:
assert len(result) == 1
assert result[0]["symbol"] == "BTCUSDT"
mock_session.execute.assert_awaited_once()
+
+
+class TestRollbackOnError:
+ @pytest.mark.asyncio
+ async def test_insert_candle_rollback_on_error(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.merge = MagicMock(return_value=None)
+ mock_session.commit = AsyncMock(side_effect=RuntimeError("db error"))
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ candle = make_candle()
+ with pytest.raises(RuntimeError, match="db error"):
+ await db.insert_candle(candle)
+
+ mock_session.rollback.assert_awaited_once()
+
+ @pytest.mark.asyncio
+ async def test_insert_signal_rollback_on_error(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.add = MagicMock()
+ mock_session.commit = AsyncMock(side_effect=RuntimeError("db error"))
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ signal = make_signal()
+ with pytest.raises(RuntimeError, match="db error"):
+ await db.insert_signal(signal)
+
+ mock_session.rollback.assert_awaited_once()
+
+
+class TestTransactionContextManager:
+ @pytest.mark.asyncio
+ async def test_transaction_context_manager_commits(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ async with db.transaction() as session:
+ session.add(MagicMock())
+
+ mock_session.commit.assert_awaited_once()
+ mock_session.rollback.assert_not_awaited()
+
+ @pytest.mark.asyncio
+ async def test_transaction_context_manager_rollback(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_session = AsyncMock()
+ mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+ mock_session.__aexit__ = AsyncMock(return_value=False)
+
+ db._session_factory = MagicMock(return_value=mock_session)
+
+ with pytest.raises(ValueError, match="test error"):
+ async with db.transaction() as session:
+ raise ValueError("test error")
+
+ mock_session.rollback.assert_awaited_once()
+ mock_session.commit.assert_not_awaited()