summaryrefslogtreecommitdiff
path: root/cli/src/trading_cli/commands/portfolio.py
blob: fd3ebd6d87652f3bce031c032a83f53d47da9482 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import asyncio
import sys
from datetime import UTC, datetime, timedelta

import click
from rich.console import Console
from rich.table import Table


@click.group()
def portfolio():
    """Portfolio management commands."""
    # Container group only: the subcommands (`show`, `history`) attach
    # themselves via `@portfolio.command()`. The docstring above doubles as
    # the group's `--help` text, so it is the whole body.


@portfolio.command()
def show():
    """Show the current portfolio holdings and balances."""
    # Database-related modules are imported on demand so that the rest of the
    # CLI keeps working (and this command fails with a readable message) when
    # they are unavailable.
    try:
        from sqlalchemy import select

        from shared.config import Settings
        from shared.db import Database
        from shared.sa_models import PositionRow
    except ImportError as exc:
        click.echo(f"Error: Could not import required modules: {exc}", err=True)
        sys.exit(1)

    # (header, add_column keyword arguments) for each table column, in order.
    column_specs = (
        ("Symbol", {"style": "bold"}),
        ("Quantity", {"justify": "right"}),
        ("Avg Entry", {"justify": "right"}),
        ("Current Price", {"justify": "right"}),
        ("Unrealized PnL", {"justify": "right"}),
        ("Updated At", {}),
    )

    async def _load_and_render():
        database = Database(Settings().database_url.get_secret_value())
        await database.connect()
        try:
            async with database.get_session() as session:
                outcome = await session.execute(select(PositionRow))
                positions = outcome.scalars().all()

            if not positions:
                click.echo("No open positions.")
                return

            out = Console()
            grid = Table(title="Current Positions", show_header=True, header_style="bold cyan")
            for header, options in column_specs:
                grid.add_column(header, **options)

            for pos in positions:
                # Unrealized PnL = quantity * (mark price - average entry price).
                unrealized = pos.quantity * (pos.current_price - pos.avg_entry_price)
                colour = "red" if unrealized < 0 else "green"
                stamp = pos.updated_at.strftime("%Y-%m-%d %H:%M:%S") if pos.updated_at else "-"
                grid.add_row(
                    pos.symbol,
                    f"{pos.quantity:.6f}",
                    f"{pos.avg_entry_price:.2f}",
                    f"{pos.current_price:.2f}",
                    f"[{colour}]{unrealized:.2f}[/{colour}]",
                    stamp,
                )

            out.print(grid)
        finally:
            # Close the connection whether rendering succeeded or not.
            await database.close()

    asyncio.run(_load_and_render())


@portfolio.command()
@click.option(
    "--days",
    default=30,
    show_default=True,
    # Reject zero/negative windows up front: they would make the lower bound
    # "now" or a future instant, so the query could never return a snapshot.
    type=click.IntRange(min=1),
    help="Number of days of history",
)
def history(days):
    """Show PnL history for the portfolio."""
    # Imported lazily so a missing dependency yields a clean error message
    # rather than breaking the whole CLI at import time.
    try:
        from sqlalchemy import select

        from shared.config import Settings
        from shared.db import Database
        from shared.sa_models import PortfolioSnapshotRow
    except ImportError as e:
        click.echo(f"Error: Could not import required modules: {e}", err=True)
        sys.exit(1)

    def _signed(value):
        # Render a PnL figure with rich markup: green when >= 0, red otherwise.
        style = "green" if value >= 0 else "red"
        return f"[{style}]{value:.2f}[/{style}]"

    async def _history():
        settings = Settings()
        db = Database(settings.database_url.get_secret_value())
        await db.connect()
        try:
            # Lookback window is measured back from the current UTC time.
            since = datetime.now(UTC) - timedelta(days=days)
            stmt = (
                select(PortfolioSnapshotRow)
                .where(PortfolioSnapshotRow.snapshot_at >= since)
                .order_by(PortfolioSnapshotRow.snapshot_at.asc())
            )
            async with db.get_session() as session:
                result = await session.execute(stmt)
                rows = result.scalars().all()

            if not rows:
                click.echo(f"No portfolio snapshots found in the last {days} days.")
                return

            console = Console()
            table = Table(
                title=f"Portfolio History (last {days} days)",
                show_header=True,
                header_style="bold cyan",
            )
            table.add_column("Date", style="bold")
            table.add_column("Total Value", justify="right")
            table.add_column("Realized PnL", justify="right")
            table.add_column("Unrealized PnL", justify="right")

            # Rows arrive oldest-first (see order_by above).
            for row in rows:
                table.add_row(
                    row.snapshot_at.strftime("%Y-%m-%d %H:%M"),
                    f"{row.total_value:.2f}",
                    _signed(row.realized_pnl),
                    _signed(row.unrealized_pnl),
                )

            console.print(table)
        finally:
            # Always release the connection, including on early return/errors.
            await db.close()

    asyncio.run(_history())