summaryrefslogtreecommitdiff
path: root/services/api/tests/test_portfolio_router.py
blob: 099392383672116789b971f0f13e101575b8330f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Tests for portfolio API router."""
import pytest
from decimal import Decimal
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
from fastapi import FastAPI

from trading_api.routers.portfolio import router


@pytest.fixture
def app():
    """Create a minimal FastAPI application exposing only the portfolio router.

    The router is mounted under the ``/portfolio`` prefix. The fixture uses
    pytest's default (function) scope, so every test receives a fresh app
    whose ``state`` it may populate without affecting other tests.
    """
    application = FastAPI()
    application.include_router(router, prefix="/portfolio")
    return application


@pytest.fixture
def mock_db():
    """Provide a fake database object together with the session it hands out.

    ``db.get_session()`` is a plain synchronous call that returns ``session``.
    ``session`` also acts as an async context manager yielding itself, so
    ``async with db.get_session() as s:`` binds ``s`` to the same mock the
    test configures.

    Returns:
        tuple: ``(db, session)`` — both are ``AsyncMock`` instances.
    """
    session = AsyncMock()
    # Make `async with` hand back this very mock rather than a fresh child mock.
    session.__aenter__ = AsyncMock(return_value=session)
    # A falsy __aexit__ result lets exceptions raised inside the block propagate.
    session.__aexit__ = AsyncMock(return_value=False)

    database = AsyncMock()
    # Synchronous factory: callers use the return value directly, without awaiting.
    database.get_session = MagicMock(return_value=session)
    return database, session


def test_get_positions_empty(app, mock_db):
    """GET /portfolio/positions yields ``[]`` when the mocked query has no rows."""
    database, session = mock_db
    app.state.db = database

    # session.execute(...) resolves to a result whose .scalars().all() is empty.
    empty_result = MagicMock()
    empty_result.scalars.return_value.all.return_value = []
    session.execute = AsyncMock(return_value=empty_result)

    response = TestClient(app).get("/portfolio/positions")

    assert response.status_code == 200
    assert response.json() == []


def test_get_positions_with_data(app, mock_db):
    """GET /portfolio/positions returns one entry for a single mocked position row."""
    database, session = mock_db
    app.state.db = database

    # Row-like stand-in; attributes are set via MagicMock's keyword configuration.
    position = MagicMock(
        symbol="BTCUSDT",
        quantity=Decimal("0.1"),
        avg_entry_price=Decimal("50000"),
        current_price=Decimal("55000"),
    )
    query_result = MagicMock()
    query_result.scalars.return_value.all.return_value = [position]
    session.execute = AsyncMock(return_value=query_result)

    response = TestClient(app).get("/portfolio/positions")

    assert response.status_code == 200
    payload = response.json()
    assert len(payload) == 1
    assert payload[0]["symbol"] == "BTCUSDT"


def test_get_snapshots_empty(app, mock_db):
    """GET /portfolio/snapshots yields ``[]`` when ``get_portfolio_snapshots`` returns nothing."""
    database = mock_db[0]
    # Same object is later stored on app.state, so configuring it first is equivalent.
    database.get_portfolio_snapshots = AsyncMock(return_value=[])
    app.state.db = database

    response = TestClient(app).get("/portfolio/snapshots")

    assert response.status_code == 200
    assert response.json() == []