summaryrefslogtreecommitdiff
path: root/shared
diff options
context:
space:
mode:
authorTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-02 15:32:37 +0900
committerTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-02 15:32:37 +0900
commit78d48f99f93b6738f4239c3f4bb718d3aa5cbb56 (patch)
tree419a66a59c8032321b2982155f1a17efd5355deb /shared
parent0e7fd5059e8a813ccffe2c376b1ff43898b4d966 (diff)
feat: add DB connection pooling with configurable pool_size, overflow, recycle
Diffstat (limited to 'shared')
-rw-r--r--shared/src/shared/config.py3
-rw-r--r--shared/src/shared/db.py19
-rw-r--r--shared/tests/test_db.py48
3 files changed, 68 insertions, 2 deletions
diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py
index b6ccebd..b6b9d69 100644
--- a/shared/src/shared/config.py
+++ b/shared/src/shared/config.py
@@ -9,6 +9,9 @@ class Settings(BaseSettings):
alpaca_paper: bool = True # Use paper trading by default
redis_url: str = "redis://localhost:6379"
database_url: str = "postgresql://trading:trading@localhost:5432/trading"
+ db_pool_size: int = 20
+ db_max_overflow: int = 10
+ db_pool_recycle: int = 3600
log_level: str = "INFO"
risk_max_position_size: float = 0.1
risk_stop_loss_pct: float = 5.0
diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py
index 9cc8686..e7cad92 100644
--- a/shared/src/shared/db.py
+++ b/shared/src/shared/db.py
@@ -36,9 +36,24 @@ class Database:
self._engine = None
self._session_factory = None
- async def connect(self) -> None:
+ async def connect(
+ self,
+ pool_size: int = 20,
+ max_overflow: int = 10,
+ pool_recycle: int = 3600,
+ ) -> None:
"""Create the async engine, session factory, and all tables."""
- self._engine = create_async_engine(self._database_url)
+ if self._database_url.startswith("sqlite"):
+ # SQLite doesn't support pooling options
+ self._engine = create_async_engine(self._database_url)
+ else:
+ self._engine = create_async_engine(
+ self._database_url,
+ pool_pre_ping=True,
+ pool_size=pool_size,
+ max_overflow=max_overflow,
+ pool_recycle=pool_recycle,
+ )
self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
index 239ee64..439a66e 100644
--- a/shared/tests/test_db.py
+++ b/shared/tests/test_db.py
@@ -101,6 +101,54 @@ class TestDatabaseConnect:
mock_create.assert_called_once()
@pytest.mark.asyncio
+ async def test_connect_passes_pool_params_for_postgres(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_conn = AsyncMock()
+ mock_cm = AsyncMock()
+ mock_cm.__aenter__.return_value = mock_conn
+
+ mock_engine = MagicMock()
+ mock_engine.begin.return_value = mock_cm
+ mock_engine.dispose = AsyncMock()
+
+ with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create:
+ with patch("shared.db.async_sessionmaker"):
+ with patch("shared.db.Base") as mock_base:
+ mock_base.metadata.create_all = MagicMock()
+ await db.connect(pool_size=5, max_overflow=3, pool_recycle=1800)
+ mock_create.assert_called_once_with(
+ "postgresql+asyncpg://host/db",
+ pool_pre_ping=True,
+ pool_size=5,
+ max_overflow=3,
+ pool_recycle=1800,
+ )
+
+ @pytest.mark.asyncio
+ async def test_connect_skips_pool_params_for_sqlite(self):
+ from shared.db import Database
+
+ db = Database("sqlite+aiosqlite:///test.db")
+
+ mock_conn = AsyncMock()
+ mock_cm = AsyncMock()
+ mock_cm.__aenter__.return_value = mock_conn
+
+ mock_engine = MagicMock()
+ mock_engine.begin.return_value = mock_cm
+ mock_engine.dispose = AsyncMock()
+
+ with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create:
+ with patch("shared.db.async_sessionmaker"):
+ with patch("shared.db.Base") as mock_base:
+ mock_base.metadata.create_all = MagicMock()
+ await db.connect()
+ mock_create.assert_called_once_with("sqlite+aiosqlite:///test.db")
+
+ @pytest.mark.asyncio
async def test_init_tables_is_alias_for_connect(self):
from shared.db import Database