diff options
| author | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 17:04:38 +0900 |
|---|---|---|
| committer | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 17:04:38 +0900 |
| commit | b4624c77de2ea615a65c04a39d657a38ff2a7c95 (patch) | |
| tree | 918be60fb38d95267f85aa5fe71aa972dc886652 /cli | |
| parent | 32bc579c3f15694308992690f4457524e8842f99 (diff) | |
feat(cli): implement backtest, strategy, portfolio, and data commands
Diffstat (limited to 'cli')
| -rw-r--r-- | cli/src/trading_cli/commands/backtest.py | 96 | ||||
| -rw-r--r-- | cli/src/trading_cli/commands/data.py | 144 | ||||
| -rw-r--r-- | cli/src/trading_cli/commands/portfolio.py | 110 | ||||
| -rw-r--r-- | cli/src/trading_cli/commands/strategy.py | 68 |
4 files changed, 400 insertions, 18 deletions
diff --git a/cli/src/trading_cli/commands/backtest.py b/cli/src/trading_cli/commands/backtest.py index 0f0cdbe..b9e3c1b 100644 --- a/cli/src/trading_cli/commands/backtest.py +++ b/cli/src/trading_cli/commands/backtest.py @@ -1,5 +1,16 @@ +import asyncio +import sys +from decimal import Decimal +from pathlib import Path + import click +# Add service source paths so we can import strategy-engine and backtester +_ROOT = Path(__file__).resolve().parents[5] +sys.path.insert(0, str(_ROOT / "services" / "strategy-engine" / "src")) +sys.path.insert(0, str(_ROOT / "services" / "strategy-engine")) +sys.path.insert(0, str(_ROOT / "services" / "backtester" / "src")) + @click.group() def backtest(): @@ -10,15 +21,85 @@ def backtest(): @backtest.command() @click.option("--strategy", required=True, help="Strategy name to backtest") @click.option("--symbol", required=True, help="Trading symbol (e.g. BTCUSDT)") -@click.option("--from", "from_date", required=True, help="Start date (ISO format)") -@click.option("--to", "to_date", default=None, help="End date (ISO format, defaults to now)") +@click.option("--timeframe", default="1h", show_default=True, help="Candle timeframe") @click.option("--balance", default=10000.0, show_default=True, help="Initial balance in USDT") -def run(strategy, symbol, from_date, to_date, balance): +@click.option( + "--output", + "output_format", + type=click.Choice(["text", "csv", "json"]), + default="text", + show_default=True, + help="Output format", +) +@click.option("--file", "file_path", default=None, help="Save output to file") +def run(strategy, symbol, timeframe, balance, output_format, file_path): """Run a backtest for a strategy.""" - to_label = to_date or "now" - click.echo( - f"Running backtest: strategy={strategy}, symbol={symbol}, {from_date} → {to_label}, balance={balance}..." - ) + try: + from strategy_engine.plugin_loader import load_strategies + from backtester.engine import BacktestEngine + from backtester.reporter import format_report, export_csv, export_json + from shared.db import Database + from shared.config import Settings + from shared.models import Candle + except ImportError as e: + click.echo(f"Error: Could not import required modules: {e}", err=True) + sys.exit(1) + + strategies_dir = _ROOT / "services" / "strategy-engine" / "strategies" + strategies = load_strategies(strategies_dir) + + matched = [s for s in strategies if s.name == strategy] + if not matched: + available = [s.name for s in strategies] + click.echo(f"Error: Strategy '{strategy}' not found. Available: {available}", err=True) + sys.exit(1) + + strat = matched[0] + + async def _run(): + settings = Settings() + db = Database(settings.database_url) + await db.connect() + try: + candle_rows = await db.get_candles(symbol, timeframe, limit=500) + if not candle_rows: + click.echo(f"Error: No candles found for {symbol} {timeframe}", err=True) + sys.exit(1) + + candles = [] + for row in reversed(candle_rows): # get_candles returns DESC, we need ASC + candles.append( + Candle( + symbol=row["symbol"], + timeframe=row["timeframe"], + open_time=row["open_time"], + open=row["open"], + high=row["high"], + low=row["low"], + close=row["close"], + volume=row["volume"], + ) + ) + + engine = BacktestEngine(strat, Decimal(str(balance))) + result = engine.run(candles) + + if output_format == "csv": + output = export_csv(result) + elif output_format == "json": + output = export_json(result) + else: + output = format_report(result) + + if file_path: + Path(file_path).write_text(output) + click.echo(f"Report saved to {file_path}") + else: + click.echo(output) + finally: + await db.close() + + asyncio.run(_run()) @backtest.command() @@ -26,3 +107,4 @@ def run(strategy, symbol, from_date, to_date, balance): def report(backtest_id): """Show a backtest report by ID.""" click.echo(f"Showing backtest report for ID: {backtest_id}...") + click.echo("(Not yet implemented - requires stored backtest results)") diff --git a/cli/src/trading_cli/commands/data.py b/cli/src/trading_cli/commands/data.py index 25d1693..5c6f274 100644 --- a/cli/src/trading_cli/commands/data.py +++ b/cli/src/trading_cli/commands/data.py @@ -1,4 +1,12 @@ +import asyncio +import sys +from pathlib import Path + import click +from rich.console import Console +from rich.table import Table + +_ROOT = Path(__file__).resolve().parents[5] @click.group() @@ -12,7 +20,16 @@ def data(): @click.option("--timeframe", default="1m", show_default=True, help="Candle timeframe") def collect(symbol, timeframe): """Start collecting live market data for a symbol.""" - click.echo(f"Starting data collection for {symbol} at {timeframe} timeframe...") + click.echo( + f"To collect live data for {symbol} at {timeframe}, run the data-collector service:" + ) + click.echo() + click.echo(" docker compose up -d data-collector") + click.echo() + click.echo("Or run directly:") + click.echo() + click.echo(f" cd {_ROOT / 'services' / 'data-collector'}") + click.echo(" python -m data_collector.main") @data.command() @@ -22,14 +39,127 @@ def collect(symbol, timeframe): @click.option("--limit", default=1000, show_default=True, help="Number of candles to fetch") def history(symbol, timeframe, since, limit): """Download historical market data for a symbol.""" - click.echo( - f"Downloading {limit} {timeframe} candles for {symbol}" - + (f" since {since}" if since else "") - + "..." - ) + sys.path.insert(0, str(_ROOT / "services" / "data-collector" / "src")) + + try: + from data_collector.binance_rest import fetch_historical_candles + from shared.db import Database + from shared.config import Settings + except ImportError as e: + click.echo(f"Error: Could not import required modules: {e}", err=True) + sys.exit(1) + + async def _fetch(): + import ccxt.async_support as ccxt + from datetime import datetime, timezone + + settings = Settings() + db = Database(settings.database_url) + await db.connect() + + # Parse the since date to a timestamp in ms + if since: + try: + dt = datetime.fromisoformat(since).replace(tzinfo=timezone.utc) + since_ms = int(dt.timestamp() * 1000) + except ValueError: + click.echo(f"Error: Invalid date format '{since}'. Use ISO format (e.g. 2024-01-01).", err=True) + sys.exit(1) + else: + # Default: fetch from 1000 candles ago (approximate) + since_ms = None + + # Normalize symbol for ccxt (BTCUSDT -> BTC/USDT) + ccxt_symbol = symbol + if "/" not in symbol and "USDT" in symbol: + base = symbol.replace("USDT", "") + ccxt_symbol = f"{base}/USDT" + + exchange = ccxt.binance({ + "apiKey": settings.binance_api_key, + "secret": settings.binance_api_secret, + }) + + try: + kwargs = {"limit": limit} + if since_ms is not None: + kwargs["since"] = since_ms + + candles = await fetch_historical_candles( + exchange, ccxt_symbol, timeframe, **kwargs + ) + + count = 0 + for candle in candles: + await db.insert_candle(candle) + count += 1 + + click.echo(f"Saved {count} candles for {symbol} ({timeframe}) to database.") + except Exception as e: + click.echo(f"Error fetching candles: {e}", err=True) + sys.exit(1) + finally: + await exchange.close() + await db.close() + + asyncio.run(_fetch()) @data.command("list") def list_(): """List available data streams and symbols.""" - click.echo("Fetching available data streams and collected symbols...") + try: + from shared.db import Database + from shared.config import Settings + from shared.sa_models import CandleRow + from sqlalchemy import select, func + except ImportError as e: + click.echo(f"Error: Could not import required modules: {e}", err=True) + sys.exit(1) + + async def _list(): + settings = Settings() + db = Database(settings.database_url) + await db.connect() + try: + stmt = ( + select( + CandleRow.symbol, + CandleRow.timeframe, + func.count().label("count"), + func.min(CandleRow.open_time).label("earliest"), + func.max(CandleRow.open_time).label("latest"), + ) + .group_by(CandleRow.symbol, CandleRow.timeframe) + .order_by(CandleRow.symbol, CandleRow.timeframe) + ) + async with db.get_session() as session: + result = await session.execute(stmt) + rows = result.all() + + if not rows: + click.echo("No data collected yet.") + return + + console = Console() + table = Table(title="Collected Data", show_header=True, header_style="bold cyan") + table.add_column("Symbol", style="bold") + table.add_column("Timeframe") + table.add_column("Candles", justify="right") + table.add_column("From") + table.add_column("To") + + for row in rows: + table.add_row( + row.symbol, + row.timeframe, + str(row.count), + row.earliest.strftime("%Y-%m-%d %H:%M") if row.earliest else "-", + row.latest.strftime("%Y-%m-%d %H:%M") if row.latest else "-", + ) + + console.print(table) + finally: + await db.close() + + asyncio.run(_list()) diff --git a/cli/src/trading_cli/commands/portfolio.py b/cli/src/trading_cli/commands/portfolio.py index 9389bac..ad9a6b4 100644 --- a/cli/src/trading_cli/commands/portfolio.py +++ b/cli/src/trading_cli/commands/portfolio.py @@ -1,4 +1,10 @@ +import asyncio +import sys +from datetime import datetime, timedelta, timezone + import click +from rich.console import Console +from rich.table import Table @click.group() @@ -10,11 +16,111 @@ def portfolio(): @portfolio.command() def show(): """Show the current portfolio holdings and balances.""" - click.echo("Fetching current portfolio...") + try: + from shared.db import Database + from shared.config import Settings + from shared.sa_models import PositionRow + from sqlalchemy import select + except ImportError as e: + click.echo(f"Error: Could not import required modules: {e}", err=True) + sys.exit(1) + + async def _show(): + settings = Settings() + db = Database(settings.database_url) + await db.connect() + try: + async with db.get_session() as session: + result = await session.execute(select(PositionRow)) + rows = result.scalars().all() + + if not rows: + click.echo("No open positions.") + return + + console = Console() + table = Table(title="Current Positions", show_header=True, header_style="bold cyan") + table.add_column("Symbol", style="bold") + table.add_column("Quantity", justify="right") + table.add_column("Avg Entry", justify="right") + table.add_column("Current Price", justify="right") + table.add_column("Unrealized PnL", justify="right") + table.add_column("Updated At") + + for row in rows: + pnl = row.quantity * (row.current_price - row.avg_entry_price) + pnl_style = "green" if pnl >= 0 else "red" + table.add_row( + row.symbol, + f"{row.quantity:.6f}", + f"{row.avg_entry_price:.2f}", + f"{row.current_price:.2f}", + f"[{pnl_style}]{pnl:.2f}[/{pnl_style}]", + row.updated_at.strftime("%Y-%m-%d %H:%M:%S") if row.updated_at else "-", + ) + + console.print(table) + finally: + await db.close() + + asyncio.run(_show()) @portfolio.command() @click.option("--days", default=30, show_default=True, help="Number of days of history") def history(days): """Show PnL history for the portfolio.""" - click.echo(f"Fetching PnL history for the last {days} days...") + try: + from shared.db import Database + from shared.config import Settings + from shared.sa_models import PortfolioSnapshotRow + from sqlalchemy import select + except ImportError as e: + click.echo(f"Error: Could not import required modules: {e}", err=True) + sys.exit(1) + + async def _history(): + settings = Settings() + db = Database(settings.database_url) + await db.connect() + try: + since = datetime.now(timezone.utc) - timedelta(days=days) + stmt = ( + select(PortfolioSnapshotRow) + .where(PortfolioSnapshotRow.snapshot_at >= since) + .order_by(PortfolioSnapshotRow.snapshot_at.asc()) + ) + async with db.get_session() as session: + result = await session.execute(stmt) + rows = result.scalars().all() + + if not rows: + click.echo(f"No portfolio snapshots found in the last {days} days.") + return + + console = Console() + table = Table( + title=f"Portfolio History (last {days} days)", + show_header=True, + header_style="bold cyan", + ) + table.add_column("Date", style="bold") + table.add_column("Total Value", justify="right") + table.add_column("Realized PnL", justify="right") + table.add_column("Unrealized PnL", justify="right") + + for row in rows: + r_style = "green" if row.realized_pnl >= 0 else "red" + u_style = "green" if row.unrealized_pnl >= 0 else "red" + table.add_row( + row.snapshot_at.strftime("%Y-%m-%d %H:%M"), + f"{row.total_value:.2f}", + f"[{r_style}]{row.realized_pnl:.2f}[/{r_style}]", + f"[{u_style}]{row.unrealized_pnl:.2f}[/{u_style}]", + ) + + console.print(table) + finally: + await db.close() + + asyncio.run(_history()) diff --git a/cli/src/trading_cli/commands/strategy.py b/cli/src/trading_cli/commands/strategy.py index 68ffeee..ea3f034 100644 --- a/cli/src/trading_cli/commands/strategy.py +++ b/cli/src/trading_cli/commands/strategy.py @@ -1,4 +1,26 @@ +import sys +from pathlib import Path + import click +from rich.console import Console +from rich.table import Table + +# Add strategy-engine to sys.path +_ROOT = Path(__file__).resolve().parents[5] +sys.path.insert(0, str(_ROOT / "services" / "strategy-engine" / "src")) +sys.path.insert(0, str(_ROOT / "services" / "strategy-engine")) + + +def _load_all_strategies(): + """Load all strategies from the strategies directory.""" + try: + from strategy_engine.plugin_loader import load_strategies + except ImportError as e: + click.echo(f"Error: Could not import plugin_loader: {e}", err=True) + sys.exit(1) + + strategies_dir = _ROOT / "services" / "strategy-engine" / "strategies" + return load_strategies(strategies_dir) @click.group() @@ -10,11 +32,53 @@ def strategy(): @strategy.command("list") def list_(): """List all available trading strategies.""" - click.echo("Fetching available strategies...") + strategies = _load_all_strategies() + + if not strategies: + click.echo("No strategies found.") + return + + console = Console() + table = Table(title="Available Strategies", show_header=True, header_style="bold cyan") + table.add_column("Name", style="bold") + table.add_column("Warmup Period", justify="right") + + for s in strategies: + table.add_row(s.name, str(s.warmup_period)) + + console.print(table) @strategy.command() @click.option("--name", required=True, help="Strategy name") def info(name): """Show detailed information about a strategy.""" - click.echo(f"Fetching details for strategy: {name}...") + strategies = _load_all_strategies() + + matched = [s for s in strategies if s.name == name] + if not matched: + available = [s.name for s in strategies] + click.echo(f"Error: Strategy '{name}' not found. Available: {available}", err=True) + sys.exit(1) + + strat = matched[0] + config_dir = _ROOT / "services" / "strategy-engine" / "strategies" / "config" + config_file = config_dir / f"{name}_strategy.yaml" + if not config_file.exists(): + # Try without _strategy suffix + config_file = config_dir / f"{name}.yaml" + + console = Console() + table = Table(title=f"Strategy: {strat.name}", show_header=True, header_style="bold cyan") + table.add_column("Property", style="bold") + table.add_column("Value") + + table.add_row("Name", strat.name) + table.add_row("Warmup Period", str(strat.warmup_period)) + table.add_row("Class", type(strat).__name__) + if config_file.exists(): + table.add_row("Config File", str(config_file)) + else: + table.add_row("Config File", "(none)") + + console.print(table) |
