diff options
Diffstat (limited to 'cli')
| -rw-r--r-- | cli/src/trading_cli/commands/backtest.py | 41 | ||||
| -rw-r--r-- | cli/src/trading_cli/commands/data.py | 20 | ||||
| -rw-r--r-- | cli/src/trading_cli/commands/portfolio.py | 18 | ||||
| -rw-r--r-- | cli/src/trading_cli/commands/service.py | 1 | ||||
| -rw-r--r-- | cli/src/trading_cli/main.py | 7 | ||||
| -rw-r--r-- | cli/tests/test_cli_strategy.py | 3 |
6 files changed, 49 insertions, 41 deletions
diff --git a/cli/src/trading_cli/commands/backtest.py b/cli/src/trading_cli/commands/backtest.py index 3876f1b..c17c61f 100644 --- a/cli/src/trading_cli/commands/backtest.py +++ b/cli/src/trading_cli/commands/backtest.py @@ -35,11 +35,12 @@ def backtest(): def run(strategy, symbol, timeframe, balance, output_format, file_path): """Run a backtest for a strategy.""" try: - from strategy_engine.plugin_loader import load_strategies from backtester.engine import BacktestEngine - from backtester.reporter import format_report, export_csv, export_json - from shared.db import Database + from backtester.reporter import export_csv, export_json, format_report + from strategy_engine.plugin_loader import load_strategies + from shared.config import Settings + from shared.db import Database from shared.models import Candle except ImportError as e: click.echo(f"Error: Could not import required modules: {e}", err=True) @@ -58,7 +59,7 @@ def run(strategy, symbol, timeframe, balance, output_format, file_path): async def _run(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: candle_rows = await db.get_candles(symbol, timeframe, limit=500) @@ -66,20 +67,19 @@ def run(strategy, symbol, timeframe, balance, output_format, file_path): click.echo(f"Error: No candles found for {symbol} {timeframe}", err=True) sys.exit(1) - candles = [] - for row in reversed(candle_rows): # get_candles returns DESC, we need ASC - candles.append( - Candle( - symbol=row["symbol"], - timeframe=row["timeframe"], - open_time=row["open_time"], - open=row["open"], - high=row["high"], - low=row["low"], - close=row["close"], - volume=row["volume"], - ) + candles = [ + Candle( + symbol=row["symbol"], + timeframe=row["timeframe"], + open_time=row["open_time"], + open=row["open"], + high=row["high"], + low=row["low"], + close=row["close"], + volume=row["volume"], ) + for row in reversed(candle_rows) # get_candles returns DESC, we need ASC + ] engine = BacktestEngine(strat, Decimal(str(balance))) result = engine.run(candles) @@ -111,10 +111,11 @@ def run(strategy, symbol, timeframe, balance, output_format, file_path): def walk_forward(strategy, symbol, timeframe, balance, windows): """Run walk-forward analysis to detect overfitting.""" try: - from strategy_engine.plugin_loader import load_strategies from backtester.walk_forward import WalkForwardEngine - from shared.db import Database + from strategy_engine.plugin_loader import load_strategies + from shared.config import Settings + from shared.db import Database from shared.models import Candle except ImportError as e: click.echo(f"Error: Could not import required modules: {e}", err=True) @@ -131,7 +132,7 @@ def walk_forward(strategy, symbol, timeframe, balance, windows): async def _run(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: rows = await db.get_candles(symbol, timeframe, limit=2000) diff --git a/cli/src/trading_cli/commands/data.py b/cli/src/trading_cli/commands/data.py index 1ecc15f..64639cf 100644 --- a/cli/src/trading_cli/commands/data.py +++ b/cli/src/trading_cli/commands/data.py @@ -1,5 +1,6 @@ import asyncio import sys +from datetime import UTC from pathlib import Path import click @@ -39,23 +40,23 @@ def history(symbol, timeframe, since, limit): """Download historical stock market data for a symbol.""" try: from shared.alpaca_client import AlpacaClient - from shared.db import Database from shared.config import Settings + from shared.db import Database except ImportError as e: click.echo(f"Error: Could not import required modules: {e}", err=True) sys.exit(1) async def _fetch(): - from datetime import datetime, timezone + from datetime import datetime settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() start = None if since: try: - start = datetime.fromisoformat(since).replace(tzinfo=timezone.utc) + start = datetime.fromisoformat(since).replace(tzinfo=UTC) except ValueError: click.echo( f"Error: Invalid date format '{since}'. Use ISO format (e.g. 2024-01-01).", @@ -64,8 +65,8 @@ def history(symbol, timeframe, since, limit): sys.exit(1) client = AlpacaClient( - api_key=settings.alpaca_api_key, - api_secret=settings.alpaca_api_secret, + api_key=settings.alpaca_api_key.get_secret_value(), + api_secret=settings.alpaca_api_secret.get_secret_value(), base_url=getattr(settings, "alpaca_base_url", "https://paper-api.alpaca.markets"), ) @@ -97,17 +98,18 @@ def history(symbol, timeframe, since, limit): def list_(): """List available data streams and symbols.""" try: - from shared.db import Database + from sqlalchemy import func, select + from shared.config import Settings + from shared.db import Database from shared.sa_models import CandleRow - from sqlalchemy import select, func except ImportError as e: click.echo(f"Error: Could not import required modules: {e}", err=True) sys.exit(1) async def _list(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: stmt = ( diff --git a/cli/src/trading_cli/commands/portfolio.py b/cli/src/trading_cli/commands/portfolio.py index ad9a6b4..fd3ebd6 100644 --- a/cli/src/trading_cli/commands/portfolio.py +++ b/cli/src/trading_cli/commands/portfolio.py @@ -1,6 +1,6 @@ import asyncio import sys -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta import click from rich.console import Console @@ -17,17 +17,18 @@ def portfolio(): def show(): """Show the current portfolio holdings and balances.""" try: - from shared.db import Database + from sqlalchemy import select + from shared.config import Settings + from shared.db import Database from shared.sa_models import PositionRow - from sqlalchemy import select except ImportError as e: click.echo(f"Error: Could not import required modules: {e}", err=True) sys.exit(1) async def _show(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: async with db.get_session() as session: @@ -71,20 +72,21 @@ def show(): def history(days): """Show PnL history for the portfolio.""" try: - from shared.db import Database + from sqlalchemy import select + from shared.config import Settings + from shared.db import Database from shared.sa_models import PortfolioSnapshotRow - from sqlalchemy import select except ImportError as e: click.echo(f"Error: Could not import required modules: {e}", err=True) sys.exit(1) async def _history(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: - since = datetime.now(timezone.utc) - timedelta(days=days) + since = datetime.now(UTC) - timedelta(days=days) stmt = ( select(PortfolioSnapshotRow) .where(PortfolioSnapshotRow.snapshot_at >= since) diff --git a/cli/src/trading_cli/commands/service.py b/cli/src/trading_cli/commands/service.py index d01eaae..6d02f14 100644 --- a/cli/src/trading_cli/commands/service.py +++ b/cli/src/trading_cli/commands/service.py @@ -1,4 +1,5 @@ import subprocess + import click diff --git a/cli/src/trading_cli/main.py b/cli/src/trading_cli/main.py index 1129bdd..0ed2307 100644 --- a/cli/src/trading_cli/main.py +++ b/cli/src/trading_cli/main.py @@ -1,10 +1,11 @@ import click -from trading_cli.commands.data import data -from trading_cli.commands.trade import trade + from trading_cli.commands.backtest import backtest +from trading_cli.commands.data import data from trading_cli.commands.portfolio import portfolio -from trading_cli.commands.strategy import strategy from trading_cli.commands.service import service +from trading_cli.commands.strategy import strategy +from trading_cli.commands.trade import trade @click.group() diff --git a/cli/tests/test_cli_strategy.py b/cli/tests/test_cli_strategy.py index cf3057b..75ba4df 100644 --- a/cli/tests/test_cli_strategy.py +++ b/cli/tests/test_cli_strategy.py @@ -1,6 +1,7 @@ """Tests for strategy CLI commands.""" -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch + from click.testing import CliRunner from trading_cli.main import cli |
