blob: 099392383672116789b971f0f13e101575b8330f (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
|
"""Tests for portfolio API router."""
import pytest
from decimal import Decimal
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
from fastapi import FastAPI
from trading_api.routers.portfolio import router
@pytest.fixture
def app():
    """Build a bare FastAPI application with the portfolio router mounted.

    The router is mounted under the ``/portfolio`` prefix, so tests address
    endpoints such as ``/portfolio/positions``.
    """
    application = FastAPI()
    application.include_router(router, prefix="/portfolio")
    return application
@pytest.fixture
def mock_db():
    """Provide a mocked database handle together with the session it hands out.

    ``db.get_session()`` is a plain synchronous call that returns a session
    mock. That mock also acts as an async context manager yielding itself, so
    ``async with db.get_session() as s`` receives the same session object.

    Returns:
        A ``(db, session)`` tuple. Tests configure ``session.execute`` or
        methods on ``db`` directly.
    """
    session = AsyncMock()
    # Entering the async context yields the session itself. Returning False
    # from __aexit__ means exceptions raised inside the block propagate.
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=False)

    database = AsyncMock()
    # get_session itself is not awaited, hence MagicMock rather than AsyncMock.
    database.get_session = MagicMock(return_value=session)
    return database, session
def test_get_positions_empty(app, mock_db):
    """GET /portfolio/positions returns an empty JSON list when the query yields no rows."""
    db, session = mock_db
    app.state.db = db

    # Mimic ``(await session.execute(...)).scalars().all()`` returning no rows.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = []
    session.execute = AsyncMock(return_value=mock_result)

    client = TestClient(app)
    response = client.get("/portfolio/positions")

    assert response.status_code == 200
    assert response.json() == []
    # Without this check the test would also pass if the endpoint returned
    # [] without ever querying the database.
    session.execute.assert_awaited()
def test_get_positions_with_data(app, mock_db):
    """GET /portfolio/positions returns one entry per row produced by the query."""
    db, session = mock_db
    app.state.db = db

    # Fake result row. The attribute names are presumably what the endpoint
    # reads from a position record; TODO confirm against the model.
    mock_row = MagicMock()
    mock_row.symbol = "BTCUSDT"
    mock_row.quantity = Decimal("0.1")
    mock_row.avg_entry_price = Decimal("50000")
    mock_row.current_price = Decimal("55000")

    # Mimic ``(await session.execute(...)).scalars().all()`` yielding that row.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [mock_row]
    session.execute = AsyncMock(return_value=mock_result)

    client = TestClient(app)
    response = client.get("/portfolio/positions")

    assert response.status_code == 200
    data = response.json()
    assert len(data) == 1
    assert data[0]["symbol"] == "BTCUSDT"
    # Confirm the data actually came through the stubbed DB query.
    session.execute.assert_awaited()
def test_get_snapshots_empty(app, mock_db):
    """GET /portfolio/snapshots returns an empty JSON list when no snapshots exist."""
    db, _ = mock_db
    app.state.db = db
    db.get_portfolio_snapshots = AsyncMock(return_value=[])

    client = TestClient(app)
    response = client.get("/portfolio/snapshots")

    assert response.status_code == 200
    assert response.json() == []
    # Guard against a vacuous pass: the endpoint must actually consult the
    # stubbed snapshot lookup rather than returning [] on its own.
    db.get_portfolio_snapshots.assert_awaited()
|