diff options
| -rw-r--r-- | shared/src/shared/config.py | 3 | ||||
| -rw-r--r-- | shared/src/shared/db.py | 19 | ||||
| -rw-r--r-- | shared/tests/test_db.py | 48 |
3 files changed, 68 insertions, 2 deletions
diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py index b6ccebd..b6b9d69 100644 --- a/shared/src/shared/config.py +++ b/shared/src/shared/config.py @@ -9,6 +9,9 @@ class Settings(BaseSettings): alpaca_paper: bool = True # Use paper trading by default redis_url: str = "redis://localhost:6379" database_url: str = "postgresql://trading:trading@localhost:5432/trading" + db_pool_size: int = 20 + db_max_overflow: int = 10 + db_pool_recycle: int = 3600 log_level: str = "INFO" risk_max_position_size: float = 0.1 risk_stop_loss_pct: float = 5.0 diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py index 9cc8686..e7cad92 100644 --- a/shared/src/shared/db.py +++ b/shared/src/shared/db.py @@ -36,9 +36,24 @@ class Database: self._engine = None self._session_factory = None - async def connect(self) -> None: + async def connect( + self, + pool_size: int = 20, + max_overflow: int = 10, + pool_recycle: int = 3600, + ) -> None: """Create the async engine, session factory, and all tables.""" - self._engine = create_async_engine(self._database_url) + if self._database_url.startswith("sqlite"): + # SQLite doesn't support pooling options + self._engine = create_async_engine(self._database_url) + else: + self._engine = create_async_engine( + self._database_url, + pool_pre_ping=True, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + ) self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False) async with self._engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index 239ee64..439a66e 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -101,6 +101,54 @@ class TestDatabaseConnect: mock_create.assert_called_once() @pytest.mark.asyncio + async def test_connect_passes_pool_params_for_postgres(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect(pool_size=5, max_overflow=3, pool_recycle=1800) + mock_create.assert_called_once_with( + "postgresql+asyncpg://host/db", + pool_pre_ping=True, + pool_size=5, + max_overflow=3, + pool_recycle=1800, + ) + + @pytest.mark.asyncio + async def test_connect_skips_pool_params_for_sqlite(self): + from shared.db import Database + + db = Database("sqlite+aiosqlite:///test.db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect() + mock_create.assert_called_once_with("sqlite+aiosqlite:///test.db") + + @pytest.mark.asyncio async def test_init_tables_is_alias_for_connect(self): from shared.db import Database |
