diff options
Diffstat (limited to 'shared/tests/test_db.py')
| -rw-r--r-- | shared/tests/test_db.py | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index 239ee64..439a66e 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -101,6 +101,54 @@ class TestDatabaseConnect: mock_create.assert_called_once() @pytest.mark.asyncio + async def test_connect_passes_pool_params_for_postgres(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect(pool_size=5, max_overflow=3, pool_recycle=1800) + mock_create.assert_called_once_with( + "postgresql+asyncpg://host/db", + pool_pre_ping=True, + pool_size=5, + max_overflow=3, + pool_recycle=1800, + ) + + @pytest.mark.asyncio + async def test_connect_skips_pool_params_for_sqlite(self): + from shared.db import Database + + db = Database("sqlite+aiosqlite:///test.db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect() + mock_create.assert_called_once_with("sqlite+aiosqlite:///test.db") + + @pytest.mark.asyncio async def test_init_tables_is_alias_for_connect(self): from shared.db import Database |
