diff options
Diffstat (limited to 'services/api/tests/test_portfolio_router.py')
| -rw-r--r-- | services/api/tests/test_portfolio_router.py | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/services/api/tests/test_portfolio_router.py b/services/api/tests/test_portfolio_router.py new file mode 100644 index 0000000..0993923 --- /dev/null +++ b/services/api/tests/test_portfolio_router.py @@ -0,0 +1,73 @@ +"""Tests for portfolio API router.""" +import pytest +from decimal import Decimal +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi.testclient import TestClient +from fastapi import FastAPI + +from trading_api.routers.portfolio import router + + +@pytest.fixture +def app(): + app = FastAPI() + app.include_router(router, prefix="/portfolio") + return app + + +@pytest.fixture +def mock_db(): + db = AsyncMock() + mock_session = AsyncMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + db.get_session = MagicMock(return_value=mock_session) + return db, mock_session + + +def test_get_positions_empty(app, mock_db): + db, session = mock_db + app.state.db = db + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] + session.execute = AsyncMock(return_value=mock_result) + + client = TestClient(app) + response = client.get("/portfolio/positions") + assert response.status_code == 200 + assert response.json() == [] + + +def test_get_positions_with_data(app, mock_db): + db, session = mock_db + app.state.db = db + + mock_row = MagicMock() + mock_row.symbol = "BTCUSDT" + mock_row.quantity = Decimal("0.1") + mock_row.avg_entry_price = Decimal("50000") + mock_row.current_price = Decimal("55000") + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_row] + session.execute = AsyncMock(return_value=mock_result) + + client = TestClient(app) + response = client.get("/portfolio/positions") + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["symbol"] == "BTCUSDT" + + +def test_get_snapshots_empty(app, mock_db): + db, _ = mock_db + app.state.db = db + db.get_portfolio_snapshots = AsyncMock(return_value=[]) + + client = TestClient(app) + response = client.get("/portfolio/snapshots") + assert response.status_code == 200 + assert response.json() == [] |
