summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.dockerignore18
-rw-r--r--.env.example34
-rw-r--r--.github/workflows/ci.yml74
-rw-r--r--CLAUDE.md106
-rw-r--r--cli/src/trading_cli/commands/backtest.py41
-rw-r--r--cli/src/trading_cli/commands/data.py20
-rw-r--r--cli/src/trading_cli/commands/portfolio.py18
-rw-r--r--cli/src/trading_cli/commands/service.py1
-rw-r--r--cli/src/trading_cli/main.py7
-rw-r--r--cli/tests/test_cli_strategy.py3
-rw-r--r--docker-compose.yml85
-rw-r--r--docs/superpowers/plans/2026-04-02-platform-upgrade.md1991
-rw-r--r--docs/superpowers/specs/2026-04-02-platform-upgrade-design.md257
-rw-r--r--monitoring/prometheus.yml2
-rw-r--r--monitoring/prometheus/alert_rules.yml29
-rw-r--r--pyproject.toml25
-rwxr-xr-xscripts/backtest_moc.py54
-rwxr-xr-xscripts/stock_screener.py4
-rw-r--r--services/api/Dockerfile15
-rw-r--r--services/api/pyproject.toml6
-rw-r--r--services/api/src/trading_api/dependencies/__init__.py0
-rw-r--r--services/api/src/trading_api/dependencies/auth.py29
-rw-r--r--services/api/src/trading_api/main.py50
-rw-r--r--services/api/src/trading_api/routers/orders.py29
-rw-r--r--services/api/src/trading_api/routers/portfolio.py22
-rw-r--r--services/api/src/trading_api/routers/strategies.py7
-rw-r--r--services/api/tests/test_api.py1
-rw-r--r--services/api/tests/test_orders_router.py6
-rw-r--r--services/api/tests/test_portfolio_router.py6
-rw-r--r--services/backtester/Dockerfile9
-rw-r--r--services/backtester/pyproject.toml2
-rw-r--r--services/backtester/src/backtester/engine.py5
-rw-r--r--services/backtester/src/backtester/main.py6
-rw-r--r--services/backtester/src/backtester/metrics.py2
-rw-r--r--services/backtester/src/backtester/simulator.py19
-rw-r--r--services/backtester/src/backtester/walk_forward.py4
-rw-r--r--services/backtester/tests/test_engine.py9
-rw-r--r--services/backtester/tests/test_metrics.py9
-rw-r--r--services/backtester/tests/test_simulator.py13
-rw-r--r--services/backtester/tests/test_walk_forward.py10
-rw-r--r--services/data-collector/Dockerfile9
-rw-r--r--services/data-collector/src/data_collector/main.py31
-rw-r--r--services/data-collector/tests/test_storage.py9
-rw-r--r--services/news-collector/Dockerfile12
-rw-r--r--services/news-collector/pyproject.toml7
-rw-r--r--services/news-collector/src/news_collector/collectors/fear_greed.py5
-rw-r--r--services/news-collector/src/news_collector/collectors/fed.py6
-rw-r--r--services/news-collector/src/news_collector/collectors/finnhub.py4
-rw-r--r--services/news-collector/src/news_collector/collectors/reddit.py4
-rw-r--r--services/news-collector/src/news_collector/collectors/rss.py6
-rw-r--r--services/news-collector/src/news_collector/collectors/sec_edgar.py10
-rw-r--r--services/news-collector/src/news_collector/collectors/truth_social.py4
-rw-r--r--services/news-collector/src/news_collector/config.py3
-rw-r--r--services/news-collector/src/news_collector/main.py81
-rw-r--r--services/news-collector/tests/test_fear_greed.py2
-rw-r--r--services/news-collector/tests/test_fed.py3
-rw-r--r--services/news-collector/tests/test_finnhub.py2
-rw-r--r--services/news-collector/tests/test_main.py8
-rw-r--r--services/news-collector/tests/test_reddit.py3
-rw-r--r--services/news-collector/tests/test_rss.py2
-rw-r--r--services/news-collector/tests/test_sec_edgar.py8
-rw-r--r--services/news-collector/tests/test_truth_social.py3
-rw-r--r--services/order-executor/Dockerfile9
-rw-r--r--services/order-executor/src/order_executor/executor.py16
-rw-r--r--services/order-executor/src/order_executor/main.py45
-rw-r--r--services/order-executor/src/order_executor/risk_manager.py26
-rw-r--r--services/order-executor/tests/test_executor.py4
-rw-r--r--services/order-executor/tests/test_risk_manager.py2
-rw-r--r--services/portfolio-manager/Dockerfile9
-rw-r--r--services/portfolio-manager/src/portfolio_manager/main.py43
-rw-r--r--services/portfolio-manager/tests/test_portfolio.py3
-rw-r--r--services/portfolio-manager/tests/test_snapshot.py3
-rw-r--r--services/strategy-engine/Dockerfile9
-rw-r--r--services/strategy-engine/pyproject.toml6
-rw-r--r--services/strategy-engine/src/strategy_engine/config.py6
-rw-r--r--services/strategy-engine/src/strategy_engine/engine.py8
-rw-r--r--services/strategy-engine/src/strategy_engine/main.py33
-rw-r--r--services/strategy-engine/src/strategy_engine/plugin_loader.py1
-rw-r--r--services/strategy-engine/src/strategy_engine/stock_selector.py210
-rw-r--r--services/strategy-engine/strategies/base.py5
-rw-r--r--services/strategy-engine/strategies/bollinger_strategy.py2
-rw-r--r--services/strategy-engine/strategies/combined_strategy.py2
-rw-r--r--services/strategy-engine/strategies/ema_crossover_strategy.py2
-rw-r--r--services/strategy-engine/strategies/grid_strategy.py5
-rw-r--r--services/strategy-engine/strategies/indicators/__init__.py16
-rw-r--r--services/strategy-engine/strategies/indicators/momentum.py2
-rw-r--r--services/strategy-engine/strategies/indicators/trend.py2
-rw-r--r--services/strategy-engine/strategies/indicators/volatility.py2
-rw-r--r--services/strategy-engine/strategies/indicators/volume.py2
-rw-r--r--services/strategy-engine/strategies/macd_strategy.py2
-rw-r--r--services/strategy-engine/strategies/moc_strategy.py4
-rw-r--r--services/strategy-engine/strategies/rsi_strategy.py2
-rw-r--r--services/strategy-engine/strategies/volume_profile_strategy.py4
-rw-r--r--services/strategy-engine/strategies/vwap_strategy.py4
-rw-r--r--services/strategy-engine/tests/test_base_filters.py7
-rw-r--r--services/strategy-engine/tests/test_bollinger_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_combined_strategy.py11
-rw-r--r--services/strategy-engine/tests/test_ema_crossover_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_engine.py8
-rw-r--r--services/strategy-engine/tests/test_grid_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_indicators.py9
-rw-r--r--services/strategy-engine/tests/test_macd_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_moc_strategy.py7
-rw-r--r--services/strategy-engine/tests/test_multi_symbol.py10
-rw-r--r--services/strategy-engine/tests/test_plugin_loader.py2
-rw-r--r--services/strategy-engine/tests/test_rsi_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_stock_selector.py37
-rw-r--r--services/strategy-engine/tests/test_strategy_validation.py8
-rw-r--r--services/strategy-engine/tests/test_volume_profile_strategy.py15
-rw-r--r--services/strategy-engine/tests/test_vwap_strategy.py12
-rw-r--r--shared/alembic/versions/001_initial_schema.py10
-rw-r--r--shared/alembic/versions/002_news_sentiment_tables.py10
-rw-r--r--shared/alembic/versions/003_add_missing_indexes.py35
-rw-r--r--shared/alembic/versions/004_add_signal_detail_columns.py25
-rw-r--r--shared/pyproject.toml32
-rw-r--r--shared/src/shared/broker.py12
-rw-r--r--shared/src/shared/config.py47
-rw-r--r--shared/src/shared/db.py51
-rw-r--r--shared/src/shared/events.py20
-rw-r--r--shared/src/shared/healthcheck.py5
-rw-r--r--shared/src/shared/metrics.py2
-rw-r--r--shared/src/shared/models.py29
-rw-r--r--shared/src/shared/notifier.py22
-rw-r--r--shared/src/shared/resilience.py146
-rw-r--r--shared/src/shared/sa_models.py3
-rw-r--r--shared/src/shared/sentiment.py41
-rw-r--r--shared/src/shared/sentiment_models.py5
-rw-r--r--shared/src/shared/shutdown.py30
-rw-r--r--shared/tests/test_alpaca.py4
-rw-r--r--shared/tests/test_broker.py5
-rw-r--r--shared/tests/test_config_validation.py29
-rw-r--r--shared/tests/test_db.py69
-rw-r--r--shared/tests/test_db_news.py13
-rw-r--r--shared/tests/test_events.py10
-rw-r--r--shared/tests/test_models.py20
-rw-r--r--shared/tests/test_news_events.py6
-rw-r--r--shared/tests/test_notifier.py2
-rw-r--r--shared/tests/test_resilience.py203
-rw-r--r--shared/tests/test_sa_models.py49
-rw-r--r--shared/tests/test_sa_news_models.py2
-rw-r--r--shared/tests/test_sentiment.py44
-rw-r--r--shared/tests/test_sentiment_aggregator.py16
-rw-r--r--shared/tests/test_sentiment_models.py14
-rw-r--r--tests/edge_cases/test_empty_data.py5
-rw-r--r--tests/edge_cases/test_extreme_values.py13
-rw-r--r--tests/edge_cases/test_strategy_reset.py29
-rw-r--r--tests/edge_cases/test_zero_volume.py11
-rw-r--r--tests/integration/test_backtest_end_to_end.py7
-rw-r--r--tests/integration/test_order_execution_flow.py5
-rw-r--r--tests/integration/test_portfolio_tracking_flow.py3
-rw-r--r--tests/integration/test_strategy_signal_flow.py11
151 files changed, 4004 insertions, 986 deletions
diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..8daace4
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,18 @@
+__pycache__
+*.pyc
+*.pyo
+.git
+.github
+.venv
+.env
+.env.*
+!.env.example
+tests/
+docs/
+*.md
+.ruff_cache
+.pytest_cache
+.mypy_cache
+monitoring/
+scripts/
+cli/
diff --git a/.env.example b/.env.example
index dcaf9a8..2cc65da 100644
--- a/.env.example
+++ b/.env.example
@@ -1,11 +1,23 @@
-# Alpaca API (get keys from https://app.alpaca.markets)
+# === SECRETS (keep secure, do not commit .env) ===
ALPACA_API_KEY=
ALPACA_API_SECRET=
-ALPACA_PAPER=true
-
-REDIS_URL=redis://localhost:6379
+POSTGRES_USER=trading
+POSTGRES_PASSWORD=trading
DATABASE_URL=postgresql+asyncpg://trading:trading@localhost:5432/trading
+REDIS_URL=redis://localhost:6379
+TELEGRAM_BOT_TOKEN=
+FINNHUB_API_KEY=
+ANTHROPIC_API_KEY=
+API_AUTH_TOKEN=
+METRICS_AUTH_TOKEN=
+
+# === CONFIGURATION ===
+ALPACA_PAPER=true
+DRY_RUN=true
+POSTGRES_DB=trading
LOG_LEVEL=INFO
+LOG_FORMAT=json
+HEALTH_PORT=8080
RISK_MAX_POSITION_SIZE=0.1
RISK_STOP_LOSS_PCT=5
RISK_DAILY_LOSS_LIMIT_PCT=10
@@ -13,27 +25,19 @@ RISK_TRAILING_STOP_PCT=0
RISK_MAX_OPEN_POSITIONS=10
RISK_VOLATILITY_LOOKBACK=20
RISK_VOLATILITY_SCALE=false
-DRY_RUN=true
-TELEGRAM_BOT_TOKEN=
TELEGRAM_CHAT_ID=
TELEGRAM_ENABLED=false
-LOG_FORMAT=json
-HEALTH_PORT=8080
-CIRCUIT_BREAKER_THRESHOLD=5
-CIRCUIT_BREAKER_TIMEOUT=60
-METRICS_AUTH_TOKEN=
# News Collector
-FINNHUB_API_KEY=
NEWS_POLL_INTERVAL=300
SENTIMENT_AGGREGATE_INTERVAL=900
# Stock Selector
-SELECTOR_CANDIDATES_TIME=15:00
-SELECTOR_FILTER_TIME=15:15
SELECTOR_FINAL_TIME=15:30
SELECTOR_MAX_PICKS=3
# LLM (for stock selector)
-ANTHROPIC_API_KEY=
ANTHROPIC_MODEL=claude-sonnet-4-20250514
+
+# === API SECURITY ===
+CORS_ORIGINS=http://localhost:3000
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..c07541b
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,74 @@
+name: CI
+
+on:
+ push:
+ branches: [master]
+ pull_request:
+ branches: [master]
+
+jobs:
+ lint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.12"
+ - run: pip install ruff
+ - run: ruff check .
+ - run: ruff format --check .
+
+ test:
+ runs-on: ubuntu-latest
+ services:
+ redis:
+ image: redis:7-alpine
+ ports: [6379:6379]
+ options: >-
+ --health-cmd "redis-cli ping"
+ --health-interval 5s
+ --health-timeout 3s
+ --health-retries 5
+ postgres:
+ image: postgres:16-alpine
+ env:
+ POSTGRES_USER: trading
+ POSTGRES_PASSWORD: trading
+ POSTGRES_DB: trading
+ ports: [5432:5432]
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 5s
+ --health-timeout 3s
+ --health-retries 5
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.12"
+ - run: |
+ pip install -e shared/[dev]
+ pip install -e services/strategy-engine/[dev]
+ pip install -e services/data-collector/[dev]
+ pip install -e services/order-executor/[dev]
+ pip install -e services/portfolio-manager/[dev]
+ pip install -e services/news-collector/[dev]
+ pip install -e services/api/[dev]
+ pip install -e services/backtester/[dev]
+ pip install pytest-cov
+ - run: pytest -v --cov=shared/src --cov=services --cov-report=xml --cov-report=term-missing
+ env:
+ DATABASE_URL: postgresql+asyncpg://trading:trading@localhost:5432/trading
+ REDIS_URL: redis://localhost:6379
+ - uses: actions/upload-artifact@v4
+ with:
+ name: coverage-report
+ path: coverage.xml
+
+ docker:
+ runs-on: ubuntu-latest
+ needs: [lint, test]
+ if: github.ref == 'refs/heads/master'
+ steps:
+ - uses: actions/checkout@v4
+ - run: docker compose build --quiet
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..6e33f57
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,106 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Project Overview
+
+US stock trading platform built as a Python microservices architecture. Uses Alpaca Markets API for market data and order execution. Services communicate via Redis Streams and persist to PostgreSQL.
+
+## Common Commands
+
+```bash
+make infra # Start Redis + Postgres (required before running services/tests)
+make up # Start all services via Docker Compose
+make down # Stop all services
+make test # Run all tests (pytest -v)
+make lint # Lint check (ruff check + format check)
+make format # Auto-fix lint + format
+make migrate # Run DB migrations (alembic upgrade head, from shared/)
+make migrate-new msg="description" # Create new migration
+make ci # Full CI: install deps, lint, test, Docker build
+make e2e # End-to-end tests
+```
+
+Run a single test file: `pytest services/strategy-engine/tests/test_rsi_strategy.py -v`
+
+## Architecture
+
+### Services (each in `services/<name>/`, each has its own Dockerfile)
+
+- **data-collector** (port 8080): Fetches stock bars from Alpaca, publishes `CandleEvent` to Redis stream `candles`
+- **news-collector** (port 8084): Continuously collects news from 7 sources (Finnhub, RSS, SEC EDGAR, Truth Social, Reddit, Fear & Greed, Fed), runs sentiment aggregation every 15 min
+- **strategy-engine** (port 8081): Consumes candle events, runs strategies, publishes `SignalEvent` to stream `signals`. Also runs the stock selector at 15:30 ET daily
+- **order-executor** (port 8082): Consumes signals, runs risk checks, places orders via Alpaca, publishes `OrderEvent` to stream `orders`
+- **portfolio-manager** (port 8083): Tracks positions, PnL, portfolio snapshots
+- **api** (port 8000): FastAPI REST endpoint layer
+- **backtester**: Offline backtesting engine with walk-forward analysis
+
+### Event Flow
+
+```
+Alpaca → data-collector → [candles stream] → strategy-engine → [signals stream] → order-executor → [orders stream] → portfolio-manager
+
+News sources → news-collector → [news stream] → sentiment aggregator → symbol_scores DB
+ ↓
+ stock selector (15:30 ET) → [selected_stocks stream] → MOC strategy → signals
+```
+
+All inter-service events use `shared/src/shared/events.py` (CandleEvent, SignalEvent, OrderEvent, NewsEvent) serialized as JSON over Redis Streams via `shared/src/shared/broker.py` (RedisBroker).
+
+### Shared Library (`shared/`)
+
+Installed as editable package (`pip install -e shared/`). Contains:
+- `models.py` — Pydantic domain models: Candle, Signal, Order, Position, NewsItem, NewsCategory
+- `sentiment_models.py` — SymbolScore, MarketSentiment, SelectedStock, Candidate
+- `sa_models.py` — SQLAlchemy ORM models (CandleRow, SignalRow, OrderRow, PortfolioSnapshotRow, NewsItemRow, SymbolScoreRow, MarketSentimentRow, StockSelectionRow)
+- `broker.py` — RedisBroker (async Redis Streams pub/sub with consumer groups)
+- `db.py` — Database class (async SQLAlchemy 2.0), includes news/sentiment/selection CRUD methods
+- `alpaca.py` — AlpacaClient (async aiohttp client for Alpaca Trading + Market Data APIs)
+- `events.py` — Event types and serialization (CandleEvent, SignalEvent, OrderEvent, NewsEvent)
+- `sentiment.py` — SentimentData (legacy gating) + SentimentAggregator (freshness-weighted composite scoring)
+- `config.py`, `logging.py`, `metrics.py`, `notifier.py` (Telegram), `resilience.py`, `healthcheck.py`
+
+DB migrations live in `shared/alembic/`.
+
+### Strategy System (`services/strategy-engine/strategies/`)
+
+Strategies extend `BaseStrategy` (in `strategies/base.py`) and implement `on_candle()`, `configure()`, `warmup_period`. The plugin loader (`strategy_engine/plugin_loader.py`) auto-discovers `*.py` files in the strategies directory and loads YAML config from `strategies/config/<strategy_name>.yaml`.
+
+BaseStrategy provides optional filters (ADX regime, volume, ATR-based stops) via `_init_filters()` and `_apply_filters()`.
+
+### News-Driven Stock Selector (`services/strategy-engine/src/strategy_engine/stock_selector.py`)
+
+Dynamic stock selection for MOC (Market on Close) trading. Runs daily at 15:30 ET via `strategy-engine`:
+
+1. **Candidate Pool**: Top 20 by sentiment score + LLM-recommended stocks from today's news
+2. **Technical Filter**: RSI 30-70, price > 20 EMA, volume > 50% average
+3. **LLM Final Selection**: Claude picks 2-3 stocks with rationale
+
+Market gating: blocks all trades when Fear & Greed ≤ 20 or VIX > 30 (`risk_off` regime).
+
+### News Collector (`services/news-collector/`)
+
+7 collectors extending `BaseCollector` in `collectors/`:
+- `finnhub.py` (5min), `rss.py` (10min), `reddit.py` (15min), `truth_social.py` (15min), `sec_edgar.py` (30min), `fear_greed.py` (1hr), `fed.py` (1hr)
+- All use VADER (nltk) for sentiment scoring
+- Provider abstraction via `BaseCollector` for future paid API swap (config change only)
+
+Sentiment aggregation (every 15min) computes per-symbol composite scores with freshness decay and category weights (policy 0.3, news 0.3, social 0.2, filing 0.2).
+
+### CLI (`cli/`)
+
+Click-based CLI installed as `trading` command. Depends on the shared library.
+
+## Tech Stack
+
+- Python 3.12+, async throughout (asyncio, aiohttp)
+- Pydantic for models, SQLAlchemy 2.0 async ORM, Alembic for migrations
+- Redis Streams for inter-service messaging
+- PostgreSQL 16 for persistence
+- Ruff for linting/formatting (line-length=100)
+- pytest + pytest-asyncio (asyncio_mode="auto")
+- Docker Compose for deployment; monitoring stack (Grafana/Prometheus/Loki) available via `--profile monitoring`
+
+## Environment
+
+Copy `.env.example` to `.env`. Key vars: `ALPACA_API_KEY`, `ALPACA_API_SECRET`, `ALPACA_PAPER=true`, `DRY_RUN=true`, `DATABASE_URL`, `REDIS_URL`, `FINNHUB_API_KEY`, `ANTHROPIC_API_KEY`. DRY_RUN=true simulates order fills without hitting Alpaca. Stock selector requires `ANTHROPIC_API_KEY` to be set.
diff --git a/cli/src/trading_cli/commands/backtest.py b/cli/src/trading_cli/commands/backtest.py
index 3876f1b..c17c61f 100644
--- a/cli/src/trading_cli/commands/backtest.py
+++ b/cli/src/trading_cli/commands/backtest.py
@@ -35,11 +35,12 @@ def backtest():
def run(strategy, symbol, timeframe, balance, output_format, file_path):
"""Run a backtest for a strategy."""
try:
- from strategy_engine.plugin_loader import load_strategies
from backtester.engine import BacktestEngine
- from backtester.reporter import format_report, export_csv, export_json
- from shared.db import Database
+ from backtester.reporter import export_csv, export_json, format_report
+ from strategy_engine.plugin_loader import load_strategies
+
from shared.config import Settings
+ from shared.db import Database
from shared.models import Candle
except ImportError as e:
click.echo(f"Error: Could not import required modules: {e}", err=True)
@@ -58,7 +59,7 @@ def run(strategy, symbol, timeframe, balance, output_format, file_path):
async def _run():
settings = Settings()
- db = Database(settings.database_url)
+ db = Database(settings.database_url.get_secret_value())
await db.connect()
try:
candle_rows = await db.get_candles(symbol, timeframe, limit=500)
@@ -66,20 +67,19 @@ def run(strategy, symbol, timeframe, balance, output_format, file_path):
click.echo(f"Error: No candles found for {symbol} {timeframe}", err=True)
sys.exit(1)
- candles = []
- for row in reversed(candle_rows): # get_candles returns DESC, we need ASC
- candles.append(
- Candle(
- symbol=row["symbol"],
- timeframe=row["timeframe"],
- open_time=row["open_time"],
- open=row["open"],
- high=row["high"],
- low=row["low"],
- close=row["close"],
- volume=row["volume"],
- )
+ candles = [
+ Candle(
+ symbol=row["symbol"],
+ timeframe=row["timeframe"],
+ open_time=row["open_time"],
+ open=row["open"],
+ high=row["high"],
+ low=row["low"],
+ close=row["close"],
+ volume=row["volume"],
)
+ for row in reversed(candle_rows) # get_candles returns DESC, we need ASC
+ ]
engine = BacktestEngine(strat, Decimal(str(balance)))
result = engine.run(candles)
@@ -111,10 +111,11 @@ def run(strategy, symbol, timeframe, balance, output_format, file_path):
def walk_forward(strategy, symbol, timeframe, balance, windows):
"""Run walk-forward analysis to detect overfitting."""
try:
- from strategy_engine.plugin_loader import load_strategies
from backtester.walk_forward import WalkForwardEngine
- from shared.db import Database
+ from strategy_engine.plugin_loader import load_strategies
+
from shared.config import Settings
+ from shared.db import Database
from shared.models import Candle
except ImportError as e:
click.echo(f"Error: Could not import required modules: {e}", err=True)
@@ -131,7 +132,7 @@ def walk_forward(strategy, symbol, timeframe, balance, windows):
async def _run():
settings = Settings()
- db = Database(settings.database_url)
+ db = Database(settings.database_url.get_secret_value())
await db.connect()
try:
rows = await db.get_candles(symbol, timeframe, limit=2000)
diff --git a/cli/src/trading_cli/commands/data.py b/cli/src/trading_cli/commands/data.py
index 1ecc15f..64639cf 100644
--- a/cli/src/trading_cli/commands/data.py
+++ b/cli/src/trading_cli/commands/data.py
@@ -1,5 +1,6 @@
import asyncio
import sys
+from datetime import UTC
from pathlib import Path
import click
@@ -39,23 +40,23 @@ def history(symbol, timeframe, since, limit):
"""Download historical stock market data for a symbol."""
try:
from shared.alpaca_client import AlpacaClient
- from shared.db import Database
from shared.config import Settings
+ from shared.db import Database
except ImportError as e:
click.echo(f"Error: Could not import required modules: {e}", err=True)
sys.exit(1)
async def _fetch():
- from datetime import datetime, timezone
+ from datetime import datetime
settings = Settings()
- db = Database(settings.database_url)
+ db = Database(settings.database_url.get_secret_value())
await db.connect()
start = None
if since:
try:
- start = datetime.fromisoformat(since).replace(tzinfo=timezone.utc)
+ start = datetime.fromisoformat(since).replace(tzinfo=UTC)
except ValueError:
click.echo(
f"Error: Invalid date format '{since}'. Use ISO format (e.g. 2024-01-01).",
@@ -64,8 +65,8 @@ def history(symbol, timeframe, since, limit):
sys.exit(1)
client = AlpacaClient(
- api_key=settings.alpaca_api_key,
- api_secret=settings.alpaca_api_secret,
+ api_key=settings.alpaca_api_key.get_secret_value(),
+ api_secret=settings.alpaca_api_secret.get_secret_value(),
base_url=getattr(settings, "alpaca_base_url", "https://paper-api.alpaca.markets"),
)
@@ -97,17 +98,18 @@ def history(symbol, timeframe, since, limit):
def list_():
"""List available data streams and symbols."""
try:
- from shared.db import Database
+ from sqlalchemy import func, select
+
from shared.config import Settings
+ from shared.db import Database
from shared.sa_models import CandleRow
- from sqlalchemy import select, func
except ImportError as e:
click.echo(f"Error: Could not import required modules: {e}", err=True)
sys.exit(1)
async def _list():
settings = Settings()
- db = Database(settings.database_url)
+ db = Database(settings.database_url.get_secret_value())
await db.connect()
try:
stmt = (
diff --git a/cli/src/trading_cli/commands/portfolio.py b/cli/src/trading_cli/commands/portfolio.py
index ad9a6b4..fd3ebd6 100644
--- a/cli/src/trading_cli/commands/portfolio.py
+++ b/cli/src/trading_cli/commands/portfolio.py
@@ -1,6 +1,6 @@
import asyncio
import sys
-from datetime import datetime, timedelta, timezone
+from datetime import UTC, datetime, timedelta
import click
from rich.console import Console
@@ -17,17 +17,18 @@ def portfolio():
def show():
"""Show the current portfolio holdings and balances."""
try:
- from shared.db import Database
+ from sqlalchemy import select
+
from shared.config import Settings
+ from shared.db import Database
from shared.sa_models import PositionRow
- from sqlalchemy import select
except ImportError as e:
click.echo(f"Error: Could not import required modules: {e}", err=True)
sys.exit(1)
async def _show():
settings = Settings()
- db = Database(settings.database_url)
+ db = Database(settings.database_url.get_secret_value())
await db.connect()
try:
async with db.get_session() as session:
@@ -71,20 +72,21 @@ def show():
def history(days):
"""Show PnL history for the portfolio."""
try:
- from shared.db import Database
+ from sqlalchemy import select
+
from shared.config import Settings
+ from shared.db import Database
from shared.sa_models import PortfolioSnapshotRow
- from sqlalchemy import select
except ImportError as e:
click.echo(f"Error: Could not import required modules: {e}", err=True)
sys.exit(1)
async def _history():
settings = Settings()
- db = Database(settings.database_url)
+ db = Database(settings.database_url.get_secret_value())
await db.connect()
try:
- since = datetime.now(timezone.utc) - timedelta(days=days)
+ since = datetime.now(UTC) - timedelta(days=days)
stmt = (
select(PortfolioSnapshotRow)
.where(PortfolioSnapshotRow.snapshot_at >= since)
diff --git a/cli/src/trading_cli/commands/service.py b/cli/src/trading_cli/commands/service.py
index d01eaae..6d02f14 100644
--- a/cli/src/trading_cli/commands/service.py
+++ b/cli/src/trading_cli/commands/service.py
@@ -1,4 +1,5 @@
import subprocess
+
import click
diff --git a/cli/src/trading_cli/main.py b/cli/src/trading_cli/main.py
index 1129bdd..0ed2307 100644
--- a/cli/src/trading_cli/main.py
+++ b/cli/src/trading_cli/main.py
@@ -1,10 +1,11 @@
import click
-from trading_cli.commands.data import data
-from trading_cli.commands.trade import trade
+
from trading_cli.commands.backtest import backtest
+from trading_cli.commands.data import data
from trading_cli.commands.portfolio import portfolio
-from trading_cli.commands.strategy import strategy
from trading_cli.commands.service import service
+from trading_cli.commands.strategy import strategy
+from trading_cli.commands.trade import trade
@click.group()
diff --git a/cli/tests/test_cli_strategy.py b/cli/tests/test_cli_strategy.py
index cf3057b..75ba4df 100644
--- a/cli/tests/test_cli_strategy.py
+++ b/cli/tests/test_cli_strategy.py
@@ -1,6 +1,7 @@
"""Tests for strategy CLI commands."""
-from unittest.mock import patch, MagicMock
+from unittest.mock import MagicMock, patch
+
from click.testing import CliRunner
from trading_cli.main import cli
diff --git a/docker-compose.yml b/docker-compose.yml
index 63630ff..60462ec 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -10,15 +10,21 @@ services:
interval: 5s
timeout: 3s
retries: 5
+ networks: [internal]
+ deploy:
+ resources:
+ limits:
+ memory: 256M
+ cpus: '0.5'
postgres:
image: postgres:16-alpine
ports:
- "5432:5432"
environment:
- POSTGRES_USER: trading
- POSTGRES_PASSWORD: trading
- POSTGRES_DB: trading
+ POSTGRES_USER: ${POSTGRES_USER:-trading}
+ POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-trading}
+ POSTGRES_DB: ${POSTGRES_DB:-trading}
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
@@ -26,6 +32,12 @@ services:
interval: 5s
timeout: 3s
retries: 5
+ networks: [internal]
+ deploy:
+ resources:
+ limits:
+ memory: 256M
+ cpus: '0.5'
data-collector:
build:
@@ -45,6 +57,12 @@ services:
timeout: 5s
retries: 3
restart: unless-stopped
+ networks: [internal]
+ deploy:
+ resources:
+ limits:
+ memory: 512M
+ cpus: '1.0'
strategy-engine:
build:
@@ -64,6 +82,12 @@ services:
timeout: 5s
retries: 3
restart: unless-stopped
+ networks: [internal]
+ deploy:
+ resources:
+ limits:
+ memory: 1G
+ cpus: '1.0'
order-executor:
build:
@@ -83,6 +107,12 @@ services:
timeout: 5s
retries: 3
restart: unless-stopped
+ networks: [internal]
+ deploy:
+ resources:
+ limits:
+ memory: 512M
+ cpus: '1.0'
portfolio-manager:
build:
@@ -102,6 +132,12 @@ services:
timeout: 5s
retries: 3
restart: unless-stopped
+ networks: [internal]
+ deploy:
+ resources:
+ limits:
+ memory: 512M
+ cpus: '1.0'
api:
build:
@@ -121,6 +157,12 @@ services:
timeout: 5s
retries: 3
restart: unless-stopped
+ networks: [internal]
+ deploy:
+ resources:
+ limits:
+ memory: 512M
+ cpus: '1.0'
news-collector:
build:
@@ -140,6 +182,12 @@ services:
timeout: 5s
retries: 3
restart: unless-stopped
+ networks: [internal]
+ deploy:
+ resources:
+ limits:
+ memory: 512M
+ cpus: '1.0'
loki:
image: grafana/loki:latest
@@ -150,6 +198,12 @@ services:
- ./monitoring/loki/loki-config.yaml:/etc/loki/local-config.yaml
- loki_data:/loki
command: -config.file=/etc/loki/local-config.yaml
+ networks: [monitoring]
+ deploy:
+ resources:
+ limits:
+ memory: 512M
+ cpus: '1.0'
promtail:
image: grafana/promtail:latest
@@ -160,6 +214,12 @@ services:
command: -config.file=/etc/promtail/config.yaml
depends_on:
- loki
+ networks: [monitoring]
+ deploy:
+ resources:
+ limits:
+ memory: 512M
+ cpus: '1.0'
prometheus:
image: prom/prometheus:latest
@@ -168,11 +228,18 @@ services:
- "9090:9090"
volumes:
- ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml
+ - ./monitoring/prometheus/alert_rules.yml:/etc/prometheus/alert_rules.yml
depends_on:
- data-collector
- strategy-engine
- order-executor
- portfolio-manager
+ networks: [internal, monitoring]
+ deploy:
+ resources:
+ limits:
+ memory: 512M
+ cpus: '1.0'
grafana:
image: grafana/grafana:latest
@@ -187,8 +254,20 @@ services:
- ./monitoring/grafana/dashboards:/var/lib/grafana/dashboards
depends_on:
- prometheus
+ networks: [internal, monitoring]
+ deploy:
+ resources:
+ limits:
+ memory: 512M
+ cpus: '1.0'
volumes:
redis_data:
postgres_data:
loki_data:
+
+networks:
+ internal:
+ driver: bridge
+ monitoring:
+ driver: bridge
diff --git a/docs/superpowers/plans/2026-04-02-platform-upgrade.md b/docs/superpowers/plans/2026-04-02-platform-upgrade.md
new file mode 100644
index 0000000..c28d287
--- /dev/null
+++ b/docs/superpowers/plans/2026-04-02-platform-upgrade.md
@@ -0,0 +1,1991 @@
+# Platform Upgrade Implementation Plan
+
+> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
+
+**Goal:** Upgrade the trading platform across 5 phases: shared library hardening, infrastructure improvements, service-level fixes, API security, and operational maturity.
+
+**Architecture:** Bottom-up approach — harden the shared library first (resilience, DB pooling, Redis resilience, config validation), then improve infrastructure (Docker, DB indexes), then fix all services (graceful shutdown, exception handling), then add API security (auth, CORS, rate limiting), and finally improve operations (CI/CD, linting, alerting).
+
+**Tech Stack:** Python 3.12, asyncio, tenacity, SQLAlchemy 2.0 async, Redis Streams, FastAPI, slowapi, Ruff, GitHub Actions, Prometheus Alertmanager
+
+---
+
+## File Structure
+
+### New Files
+- `shared/src/shared/resilience.py` — Retry decorator, circuit breaker, timeout wrapper
+- `shared/tests/test_resilience.py` — Tests for resilience module
+- `shared/alembic/versions/003_add_missing_indexes.py` — DB index migration
+- `.dockerignore` — Docker build exclusions
+- `services/api/src/trading_api/dependencies/auth.py` — Bearer token auth dependency
+- `.github/workflows/ci.yml` — GitHub Actions CI pipeline
+- `monitoring/prometheus/alert_rules.yml` — Prometheus alerting rules
+
+### Modified Files
+- `shared/src/shared/db.py` — Add connection pool config
+- `shared/src/shared/broker.py` — Add Redis resilience
+- `shared/src/shared/config.py` — Add validators, SecretStr, new fields
+- `shared/pyproject.toml` — Pin deps, add tenacity
+- `pyproject.toml` — Enhanced ruff rules, pytest-cov
+- `services/strategy-engine/src/strategy_engine/stock_selector.py` — Fix bug, deduplicate, session reuse
+- `services/*/src/*/main.py` — Signal handlers, exception specialization (all 6 services)
+- `services/*/Dockerfile` — Multi-stage builds, non-root user (all 7 Dockerfiles)
+- `services/api/pyproject.toml` — Add slowapi
+- `services/api/src/trading_api/main.py` — CORS, auth, rate limiting
+- `services/api/src/trading_api/routers/*.py` — Input validation, response models
+- `docker-compose.yml` — Remove hardcoded creds, add resource limits, networks
+- `.env.example` — Add new fields, mark secrets
+- `monitoring/prometheus.yml` — Reference alert rules
+
+---
+
+## Phase 1: Shared Library Hardening
+
+### Task 1: Implement Resilience Module
+
+**Files:**
+- Create: `shared/src/shared/resilience.py`
+- Create: `shared/tests/test_resilience.py`
+- Modify: `shared/pyproject.toml:6-18`
+
+- [ ] **Step 1: Add tenacity dependency to shared/pyproject.toml**
+
+In `shared/pyproject.toml`, add `tenacity` to the dependencies list:
+
+```toml
+dependencies = [
+ "pydantic>=2.8,<3",
+ "pydantic-settings>=2.0,<3",
+ "redis>=5.0,<6",
+ "asyncpg>=0.29,<1",
+ "sqlalchemy[asyncio]>=2.0,<3",
+ "alembic>=1.13,<2",
+ "structlog>=24.0,<25",
+ "prometheus-client>=0.20,<1",
+ "pyyaml>=6.0,<7",
+ "aiohttp>=3.9,<4",
+ "rich>=13.0,<14",
+ "tenacity>=8.2,<10",
+]
+```
+
+Note: This also pins all existing dependencies with upper bounds.
+
+- [ ] **Step 2: Write failing tests for retry_async**
+
+Create `shared/tests/test_resilience.py`:
+
+```python
+"""Tests for the resilience module."""
+
+import asyncio
+
+import pytest
+
+from shared.resilience import retry_async, CircuitBreaker, async_timeout
+
+
+class TestRetryAsync:
+ async def test_succeeds_without_retry(self):
+ call_count = 0
+
+ @retry_async(max_retries=3)
+ async def succeed():
+ nonlocal call_count
+ call_count += 1
+ return "ok"
+
+ result = await succeed()
+ assert result == "ok"
+ assert call_count == 1
+
+ async def test_retries_on_failure_then_succeeds(self):
+ call_count = 0
+
+ @retry_async(max_retries=3, base_delay=0.01)
+ async def fail_twice():
+ nonlocal call_count
+ call_count += 1
+ if call_count < 3:
+ raise ConnectionError("fail")
+ return "ok"
+
+ result = await fail_twice()
+ assert result == "ok"
+ assert call_count == 3
+
+ async def test_raises_after_max_retries(self):
+ @retry_async(max_retries=2, base_delay=0.01)
+ async def always_fail():
+ raise ConnectionError("fail")
+
+ with pytest.raises(ConnectionError):
+ await always_fail()
+
+ async def test_no_retry_on_excluded_exception(self):
+ call_count = 0
+
+ @retry_async(max_retries=3, base_delay=0.01, exclude=(ValueError,))
+ async def raise_value_error():
+ nonlocal call_count
+ call_count += 1
+ raise ValueError("bad input")
+
+ with pytest.raises(ValueError):
+ await raise_value_error()
+ assert call_count == 1
+
+
+class TestCircuitBreaker:
+ async def test_closed_allows_calls(self):
+ cb = CircuitBreaker(failure_threshold=3, cooldown=0.1)
+
+ async def succeed():
+ return "ok"
+
+ result = await cb.call(succeed)
+ assert result == "ok"
+
+ async def test_opens_after_threshold(self):
+ cb = CircuitBreaker(failure_threshold=2, cooldown=60)
+
+ async def fail():
+ raise ConnectionError("fail")
+
+ for _ in range(2):
+ with pytest.raises(ConnectionError):
+ await cb.call(fail)
+
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(fail)
+
+ async def test_half_open_after_cooldown(self):
+ cb = CircuitBreaker(failure_threshold=2, cooldown=0.05)
+
+ call_count = 0
+
+ async def fail_then_succeed():
+ nonlocal call_count
+ call_count += 1
+ if call_count <= 2:
+ raise ConnectionError("fail")
+ return "recovered"
+
+ # Trip the breaker
+ for _ in range(2):
+ with pytest.raises(ConnectionError):
+ await cb.call(fail_then_succeed)
+
+ # Wait for cooldown
+ await asyncio.sleep(0.1)
+
+ # Should allow one call (half-open)
+ result = await cb.call(fail_then_succeed)
+ assert result == "recovered"
+
+
+class TestAsyncTimeout:
+ async def test_completes_within_timeout(self):
+ async with async_timeout(1.0):
+ await asyncio.sleep(0.01)
+
+ async def test_raises_on_timeout(self):
+ with pytest.raises(asyncio.TimeoutError):
+ async with async_timeout(0.01):
+ await asyncio.sleep(1.0)
+```
+
+- [ ] **Step 3: Run tests to verify they fail**
+
+Run: `pytest shared/tests/test_resilience.py -v`
+Expected: FAIL with `ModuleNotFoundError: No module named 'shared.resilience'` (the module is only created in Step 4)
+
+- [ ] **Step 4: Implement resilience module**
+
+Write `shared/src/shared/resilience.py`:
+
+```python
+"""Resilience utilities: retry, circuit breaker, timeout."""
+
+import asyncio
+import functools
+import logging
+import time
+from contextlib import asynccontextmanager
+
+logger = logging.getLogger(__name__)
+
+
+def retry_async(
+ max_retries: int = 3,
+ base_delay: float = 1.0,
+ max_delay: float = 30.0,
+ exclude: tuple[type[Exception], ...] = (),
+):
+ """Decorator for async functions with exponential backoff + jitter.
+
+ Args:
+ max_retries: Maximum number of retry attempts.
+ base_delay: Initial delay in seconds between retries.
+ max_delay: Maximum delay cap in seconds.
+ exclude: Exception types that should NOT be retried.
+ """
+
+ def decorator(func):
+ @functools.wraps(func)
+ async def wrapper(*args, **kwargs):
+ last_exc = None
+ for attempt in range(max_retries + 1):
+ try:
+ return await func(*args, **kwargs)
+ except exclude:
+ raise
+ except Exception as exc:
+ last_exc = exc
+ if attempt == max_retries:
+ raise
+ delay = min(base_delay * (2**attempt), max_delay)
+ # Add jitter: 50-100% of delay
+ import random
+
+ delay = delay * (0.5 + random.random() * 0.5)
+ logger.warning(
+ "retry attempt=%d/%d delay=%.2fs error=%s func=%s",
+ attempt + 1,
+ max_retries,
+ delay,
+ str(exc),
+ func.__name__,
+ )
+ await asyncio.sleep(delay)
+ raise last_exc # Should not reach here, but just in case
+
+ return wrapper
+
+ return decorator
+
+
+class CircuitBreaker:
+ """Circuit breaker: opens after consecutive failures, auto-recovers after cooldown."""
+
+ def __init__(self, failure_threshold: int = 5, cooldown: float = 60.0) -> None:
+ self._failure_threshold = failure_threshold
+ self._cooldown = cooldown
+ self._failure_count = 0
+ self._last_failure_time: float = 0
+ self._state = "closed" # closed, open, half_open
+
+ async def call(self, func, *args, **kwargs):
+ if self._state == "open":
+ if time.monotonic() - self._last_failure_time >= self._cooldown:
+ self._state = "half_open"
+ else:
+ raise RuntimeError("Circuit breaker is open")
+
+ try:
+ result = await func(*args, **kwargs)
+ self._failure_count = 0
+ self._state = "closed"
+ return result
+ except Exception:
+ self._failure_count += 1
+ self._last_failure_time = time.monotonic()
+ if self._failure_count >= self._failure_threshold:
+ self._state = "open"
+ logger.error(
+ "circuit_breaker_opened failures=%d cooldown=%.0fs",
+ self._failure_count,
+ self._cooldown,
+ )
+ raise
+
+
+@asynccontextmanager
+async def async_timeout(seconds: float):
+ """Async context manager that raises TimeoutError after given seconds."""
+ try:
+ async with asyncio.timeout(seconds):
+ yield
+ except TimeoutError:
+ raise asyncio.TimeoutError(f"Operation timed out after {seconds}s")
+```
+
+- [ ] **Step 5: Run tests to verify they pass**
+
+Run: `pytest shared/tests/test_resilience.py -v`
+Expected: All 8 tests PASS
+
+- [ ] **Step 6: Commit**
+
+```bash
+git add shared/src/shared/resilience.py shared/tests/test_resilience.py shared/pyproject.toml
+git commit -m "feat: implement resilience module with retry, circuit breaker, timeout"
+```
+
+---
+
+### Task 2: Add DB Connection Pooling
+
+**Files:**
+- Modify: `shared/src/shared/db.py:39-44`
+- Modify: `shared/src/shared/config.py:10-11`
+- Modify: `shared/tests/test_db.py` (add pool config test)
+
+- [ ] **Step 1: Write failing test for pool config**
+
+Add to `shared/tests/test_db.py`:
+
+```python
+async def test_connect_configures_pool(tmp_path):
+ """Engine should be created with pool configuration."""
+ db = Database("sqlite+aiosqlite:///:memory:")
+ await db.connect()
+ engine = db._engine
+ pool = engine.pool
+ # aiosqlite uses StaticPool so we just verify connect works
+ assert engine is not None
+ await db.close()
+```
+
+- [ ] **Step 2: Add pool settings to config.py**
+
+In `shared/src/shared/config.py`, add after line 11 (`database_url`):
+
+```python
+ db_pool_size: int = 20
+ db_max_overflow: int = 10
+ db_pool_recycle: int = 3600
+```
+
+- [ ] **Step 3: Update Database.connect() with pool parameters**
+
+In `shared/src/shared/db.py`, replace line 41:
+
+```python
+ self._engine = create_async_engine(self._database_url)
+```
+
+with:
+
+```python
+ self._engine = create_async_engine(
+ self._database_url,
+ pool_pre_ping=True,
+ pool_size=pool_size,
+ max_overflow=max_overflow,
+ pool_recycle=pool_recycle,
+ )
+```
+
+Update the `connect` method signature to accept pool params:
+
+```python
+ async def connect(
+ self,
+ pool_size: int = 20,
+ max_overflow: int = 10,
+ pool_recycle: int = 3600,
+ ) -> None:
+ """Create the async engine, session factory, and all tables."""
+ if self._database_url.startswith("sqlite"):
+ self._engine = create_async_engine(self._database_url)
+ else:
+ self._engine = create_async_engine(
+ self._database_url,
+ pool_pre_ping=True,
+ pool_size=pool_size,
+ max_overflow=max_overflow,
+ pool_recycle=pool_recycle,
+ )
+ self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)
+ async with self._engine.begin() as conn:
+ await conn.run_sync(Base.metadata.create_all)
+```
+
+- [ ] **Step 4: Run tests**
+
+Run: `pytest shared/tests/test_db.py -v`
+Expected: PASS
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add shared/src/shared/db.py shared/src/shared/config.py shared/tests/test_db.py
+git commit -m "feat: add DB connection pooling with configurable pool_size, overflow, recycle"
+```
+
+---
+
+### Task 3: Add Redis Resilience
+
+**Files:**
+- Modify: `shared/src/shared/broker.py:1-13,15-18,102-104`
+- Create: `shared/tests/test_broker_resilience.py`
+
+- [ ] **Step 1: Write failing tests for Redis resilience**
+
+Create `shared/tests/test_broker_resilience.py`:
+
+```python
+"""Tests for Redis broker resilience features."""
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from shared.broker import RedisBroker
+
+
+class TestBrokerResilience:
+ async def test_publish_retries_on_connection_error(self):
+ broker = RedisBroker.__new__(RedisBroker)
+ mock_redis = AsyncMock()
+ call_count = 0
+
+ async def xadd_failing(*args, **kwargs):
+ nonlocal call_count
+ call_count += 1
+ if call_count < 3:
+ raise ConnectionError("Redis connection lost")
+ return "msg-id"
+
+ mock_redis.xadd = xadd_failing
+ broker._redis = mock_redis
+
+ await broker.publish("test-stream", {"key": "value"})
+ assert call_count == 3
+
+ async def test_ping_retries_on_timeout(self):
+ broker = RedisBroker.__new__(RedisBroker)
+ mock_redis = AsyncMock()
+ call_count = 0
+
+ async def ping_failing():
+ nonlocal call_count
+ call_count += 1
+ if call_count < 2:
+ raise TimeoutError("timeout")
+ return True
+
+ mock_redis.ping = ping_failing
+ broker._redis = mock_redis
+
+ result = await broker.ping()
+ assert result is True
+ assert call_count == 2
+```
+
+- [ ] **Step 2: Run tests to verify they fail**
+
+Run: `pytest shared/tests/test_broker_resilience.py -v`
+Expected: FAIL (publish doesn't retry)
+
+- [ ] **Step 3: Add resilience to broker.py**
+
+Replace `shared/src/shared/broker.py`:
+
+```python
+"""Redis Streams broker for the trading platform."""
+
+import json
+import logging
+from typing import Any
+
+import redis.asyncio
+
+from shared.resilience import retry_async
+
+logger = logging.getLogger(__name__)
+
+
+class RedisBroker:
+ """Async Redis Streams broker for publishing and reading events."""
+
+ def __init__(self, redis_url: str) -> None:
+ self._redis = redis.asyncio.from_url(
+ redis_url,
+ socket_keepalive=True,
+ health_check_interval=30,
+ retry_on_timeout=True,
+ )
+
+ @retry_async(max_retries=3, base_delay=0.5, exclude=(ValueError,))
+ async def publish(self, stream: str, data: dict[str, Any]) -> None:
+ """Publish a message to a Redis stream."""
+ payload = json.dumps(data)
+ await self._redis.xadd(stream, {"payload": payload})
+
+ async def ensure_group(self, stream: str, group: str) -> None:
+ """Create a consumer group if it doesn't exist."""
+ try:
+ await self._redis.xgroup_create(stream, group, id="0", mkstream=True)
+ except redis.ResponseError as e:
+ if "BUSYGROUP" not in str(e):
+ raise
+
+ @retry_async(max_retries=3, base_delay=0.5, exclude=(ValueError,))
+ async def read_group(
+ self,
+ stream: str,
+ group: str,
+ consumer: str,
+ count: int = 10,
+ block: int = 0,
+ ) -> list[tuple[str, dict[str, Any]]]:
+ """Read messages from a consumer group. Returns list of (message_id, data)."""
+ results = await self._redis.xreadgroup(
+ group, consumer, {stream: ">"}, count=count, block=block
+ )
+ messages = []
+ if results:
+ for _stream, entries in results:
+ for msg_id, fields in entries:
+ payload = fields.get(b"payload") or fields.get("payload")
+ if payload:
+ if isinstance(payload, bytes):
+ payload = payload.decode()
+ if isinstance(msg_id, bytes):
+ msg_id = msg_id.decode()
+ messages.append((msg_id, json.loads(payload)))
+ return messages
+
+ async def ack(self, stream: str, group: str, *msg_ids: str) -> None:
+ """Acknowledge messages in a consumer group."""
+ if msg_ids:
+ await self._redis.xack(stream, group, *msg_ids)
+
+ async def read_pending(
+ self,
+ stream: str,
+ group: str,
+ consumer: str,
+ count: int = 10,
+ ) -> list[tuple[str, dict[str, Any]]]:
+ """Read pending (unacknowledged) messages for this consumer."""
+ results = await self._redis.xreadgroup(group, consumer, {stream: "0"}, count=count)
+ messages = []
+ if results:
+ for _stream, entries in results:
+ for msg_id, fields in entries:
+ if not fields:
+ continue
+ payload = fields.get(b"payload") or fields.get("payload")
+ if payload:
+ if isinstance(payload, bytes):
+ payload = payload.decode()
+ if isinstance(msg_id, bytes):
+ msg_id = msg_id.decode()
+ messages.append((msg_id, json.loads(payload)))
+ return messages
+
+ async def read(
+ self,
+ stream: str,
+ last_id: str = "$",
+ count: int = 10,
+ block: int = 0,
+ ) -> list[dict[str, Any]]:
+ """Read messages (original method, kept for backward compatibility)."""
+ results = await self._redis.xread({stream: last_id}, count=count, block=block)
+ messages = []
+ if results:
+ for _stream, entries in results:
+ for _msg_id, fields in entries:
+ payload = fields.get(b"payload") or fields.get("payload")
+ if payload:
+ if isinstance(payload, bytes):
+ payload = payload.decode()
+ messages.append(json.loads(payload))
+ return messages
+
+ @retry_async(max_retries=2, base_delay=0.5)
+ async def ping(self) -> bool:
+ """Ping the Redis server; return True if reachable."""
+ return await self._redis.ping()
+
+ async def close(self) -> None:
+ """Close the Redis connection."""
+ await self._redis.aclose()
+```
+
+- [ ] **Step 4: Run tests**
+
+Run: `pytest shared/tests/test_broker_resilience.py -v`
+Expected: PASS
+
+Run: `pytest shared/tests/test_broker.py -v`
+Expected: PASS (existing tests still work)
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add shared/src/shared/broker.py shared/tests/test_broker_resilience.py
+git commit -m "feat: add retry and resilience to Redis broker with keepalive"
+```
+
+---
+
+### Task 4: Config Validation & SecretStr
+
+**Files:**
+- Modify: `shared/src/shared/config.py`
+- Create: `shared/tests/test_config_validation.py`
+
+- [ ] **Step 1: Write failing tests for config validation**
+
+Create `shared/tests/test_config_validation.py`:
+
+```python
+"""Tests for config validation."""
+
+import pytest
+from pydantic import ValidationError
+
+from shared.config import Settings
+
+
+class TestConfigValidation:
+ def test_valid_defaults(self):
+ settings = Settings()
+ assert settings.risk_max_position_size == 0.1
+
+ def test_invalid_position_size(self):
+ with pytest.raises(ValidationError, match="risk_max_position_size"):
+ Settings(risk_max_position_size=-0.1)
+
+ def test_invalid_health_port(self):
+ with pytest.raises(ValidationError, match="health_port"):
+ Settings(health_port=80)
+
+ def test_invalid_log_level(self):
+ with pytest.raises(ValidationError, match="log_level"):
+ Settings(log_level="INVALID")
+
+ def test_secret_fields_masked(self):
+ settings = Settings(alpaca_api_key="my-secret-key")
+ assert "my-secret-key" not in repr(settings)
+ assert settings.alpaca_api_key.get_secret_value() == "my-secret-key"
+```
+
+- [ ] **Step 2: Run tests to verify they fail**
+
+Run: `pytest shared/tests/test_config_validation.py -v`
+Expected: FAIL
+
+- [ ] **Step 3: Update config.py with validators and SecretStr**
+
+Replace `shared/src/shared/config.py`:
+
+```python
+"""Shared configuration settings for the trading platform."""
+
+from pydantic import SecretStr, field_validator
+from pydantic_settings import BaseSettings
+
+
+class Settings(BaseSettings):
+ # Alpaca
+ alpaca_api_key: SecretStr = SecretStr("")
+ alpaca_api_secret: SecretStr = SecretStr("")
+ alpaca_paper: bool = True
+ # Infrastructure
+ redis_url: SecretStr = SecretStr("redis://localhost:6379")
+ database_url: SecretStr = SecretStr("postgresql://trading:trading@localhost:5432/trading")
+ # DB pool
+ db_pool_size: int = 20
+ db_max_overflow: int = 10
+ db_pool_recycle: int = 3600
+ # Logging
+ log_level: str = "INFO"
+ log_format: str = "json"
+ # Health
+ health_port: int = 8080
+ metrics_auth_token: str = ""
+ # Risk
+ risk_max_position_size: float = 0.1
+ risk_stop_loss_pct: float = 5.0
+ risk_daily_loss_limit_pct: float = 10.0
+ risk_trailing_stop_pct: float = 0.0
+ risk_max_open_positions: int = 10
+ risk_volatility_lookback: int = 20
+ risk_volatility_scale: bool = False
+ risk_max_portfolio_exposure: float = 0.8
+ risk_max_correlated_exposure: float = 0.5
+ risk_correlation_threshold: float = 0.7
+ risk_var_confidence: float = 0.95
+ risk_var_limit_pct: float = 5.0
+ risk_drawdown_reduction_threshold: float = 0.1
+ risk_drawdown_halt_threshold: float = 0.2
+ risk_max_consecutive_losses: int = 5
+ risk_loss_pause_minutes: int = 60
+ dry_run: bool = True
+ # Telegram
+ telegram_bot_token: SecretStr = SecretStr("")
+ telegram_chat_id: str = ""
+ telegram_enabled: bool = False
+ # News
+ finnhub_api_key: SecretStr = SecretStr("")
+ news_poll_interval: int = 300
+ sentiment_aggregate_interval: int = 900
+ # Stock selector
+ selector_final_time: str = "15:30"
+ selector_max_picks: int = 3
+ # LLM
+ anthropic_api_key: SecretStr = SecretStr("")
+ anthropic_model: str = "claude-sonnet-4-20250514"
+ # API security
+ api_auth_token: SecretStr = SecretStr("")
+ cors_origins: str = "http://localhost:3000"
+
+ model_config = {"env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"}
+
+ @field_validator("risk_max_position_size")
+ @classmethod
+ def validate_position_size(cls, v: float) -> float:
+ if v <= 0 or v > 1:
+            raise ValueError("risk_max_position_size must be in the interval (0, 1]")
+ return v
+
+ @field_validator("health_port")
+ @classmethod
+ def validate_health_port(cls, v: int) -> int:
+ if v < 1024 or v > 65535:
+ raise ValueError("health_port must be between 1024 and 65535")
+ return v
+
+ @field_validator("log_level")
+ @classmethod
+ def validate_log_level(cls, v: str) -> str:
+ valid = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
+ if v.upper() not in valid:
+ raise ValueError(f"log_level must be one of {valid}")
+ return v.upper()
+```
+
+- [ ] **Step 4: Update all consumers to use .get_secret_value()**
+
+Every place that reads `settings.alpaca_api_key` etc. must now call `.get_secret_value()`. Key files to update:
+
+**`shared/src/shared/alpaca.py`** — where AlpacaClient is instantiated (in each service main.py), change:
+```python
+# Before:
+alpaca = AlpacaClient(cfg.alpaca_api_key, cfg.alpaca_api_secret, paper=cfg.alpaca_paper)
+# After:
+alpaca = AlpacaClient(
+ cfg.alpaca_api_key.get_secret_value(),
+ cfg.alpaca_api_secret.get_secret_value(),
+ paper=cfg.alpaca_paper,
+)
+```
+
+**Each service main.py** — where `Database(cfg.database_url)` and `RedisBroker(cfg.redis_url)` are called:
+```python
+# Before:
+db = Database(cfg.database_url)
+broker = RedisBroker(cfg.redis_url)
+# After:
+db = Database(cfg.database_url.get_secret_value())
+broker = RedisBroker(cfg.redis_url.get_secret_value())
+```
+
+**`shared/src/shared/notifier.py`** — where telegram_bot_token is used:
+```python
+# Change token access to .get_secret_value()
+```
+
+**`services/strategy-engine/src/strategy_engine/main.py`** — where anthropic_api_key is passed:
+```python
+# Before:
+anthropic_api_key=cfg.anthropic_api_key,
+# After:
+anthropic_api_key=cfg.anthropic_api_key.get_secret_value(),
+```
+
+**`services/news-collector/src/news_collector/main.py`** — where finnhub_api_key is used:
+```python
+# Before:
+cfg.finnhub_api_key
+# After:
+cfg.finnhub_api_key.get_secret_value()
+```
+
+- [ ] **Step 5: Run all tests**
+
+Run: `pytest shared/tests/test_config_validation.py -v`
+Expected: PASS
+
+Run: `pytest -v`
+Expected: All tests PASS (no regressions from SecretStr changes)
+
+- [ ] **Step 6: Commit**
+
+```bash
+git add shared/src/shared/config.py shared/tests/test_config_validation.py
+git add services/*/src/*/main.py shared/src/shared/notifier.py
+git commit -m "feat: add config validation, SecretStr for secrets, API security fields"
+```
+
+---
+
+### Task 5: Pin All Dependencies
+
+**Files:**
+- Modify: `shared/pyproject.toml` (already done in Task 1)
+- Modify: `services/strategy-engine/pyproject.toml`
+- Modify: `services/backtester/pyproject.toml`
+- Modify: `services/api/pyproject.toml`
+- Modify: `services/news-collector/pyproject.toml`
+- Modify: `services/data-collector/pyproject.toml`
+- Modify: `services/order-executor/pyproject.toml`
+- Modify: `services/portfolio-manager/pyproject.toml`
+
+- [ ] **Step 1: Pin service dependencies**
+
+`services/strategy-engine/pyproject.toml`:
+```toml
+dependencies = [
+ "pandas>=2.1,<3",
+ "numpy>=1.26,<3",
+ "trading-shared",
+]
+```
+
+`services/backtester/pyproject.toml`:
+```toml
+dependencies = ["pandas>=2.1,<3", "numpy>=1.26,<3", "rich>=13.0,<14", "trading-shared"]
+```
+
+`services/api/pyproject.toml`:
+```toml
+dependencies = [
+ "fastapi>=0.110,<1",
+ "uvicorn>=0.27,<1",
+ "slowapi>=0.1.9,<1",
+ "trading-shared",
+]
+```
+
+`services/news-collector/pyproject.toml`:
+```toml
+dependencies = [
+ "trading-shared",
+ "feedparser>=6.0,<7",
+ "nltk>=3.8,<4",
+ "aiohttp>=3.9,<4",
+]
+```
+
+`shared/pyproject.toml` optional deps:
+```toml
+[project.optional-dependencies]
+dev = [
+ "pytest>=8.0,<9",
+ "pytest-asyncio>=0.23,<1",
+ "ruff>=0.4,<1",
+]
+claude = [
+ "anthropic>=0.40,<1",
+]
+```
+
+- [ ] **Step 2: Verify installation works**
+
+Run: `pip install -e shared/ && pip install -e services/strategy-engine/ && pip install -e services/api/`
+Expected: No errors
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add shared/pyproject.toml services/*/pyproject.toml
+git commit -m "chore: pin all dependencies with upper bounds"
+```
+
+---
+
+## Phase 2: Infrastructure Hardening
+
+### Task 6: Docker Secrets & Environment Cleanup
+
+**Files:**
+- Modify: `docker-compose.yml:17-21`
+- Modify: `.env.example`
+
+- [ ] **Step 1: Replace hardcoded Postgres credentials in docker-compose.yml**
+
+In `docker-compose.yml`, replace the postgres service environment:
+
+```yaml
+ postgres:
+ image: postgres:16-alpine
+ environment:
+ POSTGRES_USER: ${POSTGRES_USER:-trading}
+ POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-trading}
+ POSTGRES_DB: ${POSTGRES_DB:-trading}
+```
+
+- [ ] **Step 2: Update .env.example with secret annotations**
+
+Add to `.env.example`:
+
+```bash
+# === SECRETS (keep secure, do not commit .env) ===
+ALPACA_API_KEY=
+ALPACA_API_SECRET=
+DATABASE_URL=postgresql+asyncpg://trading:trading@localhost:5432/trading
+REDIS_URL=redis://localhost:6379
+TELEGRAM_BOT_TOKEN=
+FINNHUB_API_KEY=
+ANTHROPIC_API_KEY=
+API_AUTH_TOKEN=
+POSTGRES_USER=trading
+POSTGRES_PASSWORD=trading
+POSTGRES_DB=trading
+
+# === CONFIGURATION ===
+ALPACA_PAPER=true
+DRY_RUN=true
+LOG_LEVEL=INFO
+LOG_FORMAT=json
+HEALTH_PORT=8080
+# ... (keep existing config vars)
+
+# === API SECURITY ===
+CORS_ORIGINS=http://localhost:3000
+```
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add docker-compose.yml .env.example
+git commit -m "fix: move hardcoded postgres credentials to .env, annotate secrets"
+```
+
+---
+
+### Task 7: Dockerfile Optimization
+
+**Files:**
+- Create: `.dockerignore`
+- Modify: All 7 Dockerfiles in `services/*/Dockerfile`
+
+- [ ] **Step 1: Create .dockerignore**
+
+Create `.dockerignore` at project root:
+
+```
+__pycache__
+*.pyc
+*.pyo
+.git
+.github
+.venv
+.env
+.env.*
+!.env.example
+tests/
+docs/
+*.md
+.ruff_cache
+.pytest_cache
+.mypy_cache
+monitoring/
+scripts/
+cli/
+```
+
+- [ ] **Step 2: Update data-collector Dockerfile**
+
+Replace `services/data-collector/Dockerfile`:
+
+```dockerfile
+FROM python:3.12-slim AS builder
+WORKDIR /app
+COPY shared/ shared/
+RUN pip install --no-cache-dir ./shared
+COPY services/data-collector/ services/data-collector/
+RUN pip install --no-cache-dir ./services/data-collector
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
+ENV PYTHONPATH=/app
+USER appuser
+CMD ["python", "-m", "data_collector.main"]
+```
+
+- [ ] **Step 3: Update all other Dockerfiles with same pattern**
+
+Apply the same multi-stage + non-root pattern to:
+- `services/strategy-engine/Dockerfile` (also copies strategies/)
+- `services/order-executor/Dockerfile`
+- `services/portfolio-manager/Dockerfile`
+- `services/api/Dockerfile` (also copies strategies/, uses uvicorn CMD)
+- `services/news-collector/Dockerfile` (also runs nltk download)
+- `services/backtester/Dockerfile` (also copies strategies/)
+
+For **strategy-engine** Dockerfile:
+```dockerfile
+FROM python:3.12-slim AS builder
+WORKDIR /app
+COPY shared/ shared/
+RUN pip install --no-cache-dir ./shared
+COPY services/strategy-engine/ services/strategy-engine/
+RUN pip install --no-cache-dir ./services/strategy-engine
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
+COPY services/strategy-engine/strategies/ /app/strategies/
+ENV PYTHONPATH=/app
+USER appuser
+CMD ["python", "-m", "strategy_engine.main"]
+```
+
+For **news-collector** Dockerfile:
+```dockerfile
+FROM python:3.12-slim AS builder
+WORKDIR /app
+COPY shared/ shared/
+RUN pip install --no-cache-dir ./shared
+COPY services/news-collector/ services/news-collector/
+RUN pip install --no-cache-dir ./services/news-collector
+RUN python -c "import nltk; nltk.download('vader_lexicon', download_dir='/usr/local/nltk_data')"
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
+COPY --from=builder /usr/local/nltk_data /usr/local/nltk_data
+ENV PYTHONPATH=/app
+USER appuser
+CMD ["python", "-m", "news_collector.main"]
+```
+
+For **api** Dockerfile:
+```dockerfile
+FROM python:3.12-slim AS builder
+WORKDIR /app
+COPY shared/ shared/
+RUN pip install --no-cache-dir ./shared
+COPY services/api/ services/api/
+RUN pip install --no-cache-dir ./services/api
+COPY services/strategy-engine/ services/strategy-engine/
+RUN pip install --no-cache-dir ./services/strategy-engine
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
+COPY services/strategy-engine/strategies/ /app/strategies/
+ENV PYTHONPATH=/app STRATEGIES_DIR=/app/strategies
+USER appuser
+CMD ["uvicorn", "trading_api.main:app", "--host", "0.0.0.0", "--port", "8000", "--timeout-graceful-shutdown", "30"]
+```
+
+For **order-executor**, **portfolio-manager**, **backtester** — same pattern as data-collector, adjusting the service name and CMD.
+
+- [ ] **Step 4: Verify Docker build works**
+
+Run: `docker compose build --quiet`
+Expected: All images build successfully
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add .dockerignore services/*/Dockerfile
+git commit -m "feat: optimize Dockerfiles with multi-stage builds, non-root user, .dockerignore"
+```
+
+---
+
+### Task 8: Database Index Migration
+
+**Files:**
+- Create: `shared/alembic/versions/003_add_missing_indexes.py`
+
+- [ ] **Step 1: Create migration file**
+
+Create `shared/alembic/versions/003_add_missing_indexes.py`:
+
+```python
+"""Add missing indexes for common query patterns.
+
+Revision ID: 003
+Revises: 002
+"""
+
+from alembic import op
+
+revision = "003"
+down_revision = "002"
+
+
+def upgrade():
+ op.create_index("idx_signals_symbol_created", "signals", ["symbol", "created_at"])
+ op.create_index("idx_orders_symbol_status_created", "orders", ["symbol", "status", "created_at"])
+ op.create_index("idx_trades_order_id", "trades", ["order_id"])
+ op.create_index("idx_trades_symbol_traded", "trades", ["symbol", "traded_at"])
+ op.create_index("idx_portfolio_snapshots_at", "portfolio_snapshots", ["snapshot_at"])
+ op.create_index("idx_symbol_scores_symbol", "symbol_scores", ["symbol"], unique=True)
+
+
+def downgrade():
+ op.drop_index("idx_symbol_scores_symbol", table_name="symbol_scores")
+ op.drop_index("idx_portfolio_snapshots_at", table_name="portfolio_snapshots")
+ op.drop_index("idx_trades_symbol_traded", table_name="trades")
+ op.drop_index("idx_trades_order_id", table_name="trades")
+ op.drop_index("idx_orders_symbol_status_created", table_name="orders")
+ op.drop_index("idx_signals_symbol_created", table_name="signals")
+```
+
+- [ ] **Step 2: Verify migration runs (requires infra)**
+
+Run: `make infra && cd shared && alembic upgrade head`
+Expected: Migration 003 applied successfully
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add shared/alembic/versions/003_add_missing_indexes.py
+git commit -m "feat: add missing DB indexes for signals, orders, trades, snapshots"
+```
+
+---
+
+### Task 9: Docker Compose Resource Limits & Networks
+
+**Files:**
+- Modify: `docker-compose.yml`
+
+- [ ] **Step 1: Add networks and resource limits**
+
+Add to `docker-compose.yml` at bottom:
+
+```yaml
+networks:
+ internal:
+ driver: bridge
+ monitoring:
+ driver: bridge
+```
+
+Add `networks: [internal]` to all application services (redis, postgres, data-collector, strategy-engine, order-executor, portfolio-manager, api, news-collector).
+
+Add `networks: [internal, monitoring]` to prometheus, grafana. Add `networks: [monitoring]` to loki, promtail.
+
+Add to each application service:
+
+```yaml
+ deploy:
+ resources:
+ limits:
+ memory: 512M
+ cpus: '1.0'
+```
+
+For strategy-engine and backtester, use `memory: 1G` instead.
+
+- [ ] **Step 2: Verify compose config is valid**
+
+Run: `docker compose config --quiet`
+Expected: No errors
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add docker-compose.yml
+git commit -m "feat: add resource limits and network isolation to docker-compose"
+```
+
+---
+
+## Phase 3: Service-Level Improvements
+
+### Task 10: Graceful Shutdown for All Services
+
+**Files:**
+- Modify: `services/data-collector/src/data_collector/main.py`
+- Modify: `services/strategy-engine/src/strategy_engine/main.py`
+- Modify: `services/order-executor/src/order_executor/main.py`
+- Modify: `services/portfolio-manager/src/portfolio_manager/main.py`
+- Modify: `services/news-collector/src/news_collector/main.py`
+- Modify: `services/api/src/trading_api/main.py`
+
+- [ ] **Step 1: Create a shared shutdown helper**
+
+Add to `shared/src/shared/shutdown.py`:
+
+```python
+"""Graceful shutdown utilities for services."""
+
+import asyncio
+import logging
+import signal
+
+logger = logging.getLogger(__name__)
+
+
+class GracefulShutdown:
+ """Manages graceful shutdown via SIGTERM/SIGINT signals."""
+
+ def __init__(self) -> None:
+ self._event = asyncio.Event()
+
+ @property
+ def is_shutting_down(self) -> bool:
+ return self._event.is_set()
+
+ async def wait(self) -> None:
+ await self._event.wait()
+
+ def trigger(self) -> None:
+ logger.info("shutdown_signal_received")
+ self._event.set()
+
+ def install_handlers(self) -> None:
+ loop = asyncio.get_running_loop()
+ for sig in (signal.SIGTERM, signal.SIGINT):
+ loop.add_signal_handler(sig, self.trigger)
+```
+
+- [ ] **Step 2: Add shutdown to data-collector main loop**
+
+In `services/data-collector/src/data_collector/main.py`, add at the start of `run()`:
+
+```python
+from shared.shutdown import GracefulShutdown
+
+shutdown = GracefulShutdown()
+shutdown.install_handlers()
+```
+
+Replace the main `while True:` loop with `while not shutdown.is_shutting_down:`.
+
+- [ ] **Step 3: Apply same pattern to all other services**
+
+For each service's `main.py`, add `GracefulShutdown` import, install handlers at start of `run()`, and replace infinite loops with `while not shutdown.is_shutting_down`.
+
+For strategy-engine: also cancel tasks on shutdown.
+For portfolio-manager: also cancel snapshot_loop task.
+For news-collector: also cancel all collector loop tasks.
+
+- [ ] **Step 4: Run tests**
+
+Run: `pytest -v`
+Expected: All tests PASS
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add shared/src/shared/shutdown.py services/*/src/*/main.py
+git commit -m "feat: add graceful shutdown with SIGTERM/SIGINT handlers to all services"
+```
+
+---
+
+### Task 11: Exception Handling Specialization
+
+**Files:**
+- Modify: All service `main.py` files
+- Modify: `shared/src/shared/db.py`
+
+- [ ] **Step 1: Specialize exceptions in data-collector/main.py**
+
+Replace broad `except Exception` blocks. For example, in the fetch loop:
+
+```python
+# Before:
+except Exception as exc:
+ log.warning("fetch_bar_failed", symbol=symbol, error=str(exc))
+
+# After:
+except (ConnectionError, TimeoutError, aiohttp.ClientError) as exc:
+ log.warning("fetch_bar_network_error", symbol=symbol, error=str(exc))
+except (ValueError, KeyError) as exc:
+ log.warning("fetch_bar_parse_error", symbol=symbol, error=str(exc))
+except Exception as exc:
+ log.error("fetch_bar_unexpected", symbol=symbol, error=str(exc), exc_info=True)
+```
+
+- [ ] **Step 2: Specialize exceptions in strategy-engine, order-executor, portfolio-manager, news-collector**
+
+Apply the same pattern: network errors → warning + retry, parse errors → warning + skip, unexpected → error + exc_info.
+
+- [ ] **Step 3: Specialize exceptions in db.py**
+
+In `shared/src/shared/db.py`, the transaction pattern can distinguish:
+
+```python
+except (asyncpg.PostgresError, sqlalchemy.exc.OperationalError) as exc:
+ await session.rollback()
+ logger.error("db_operation_error", error=str(exc))
+ raise
+except Exception:
+ await session.rollback()
+ raise
+```
+
+- [ ] **Step 4: Run tests**
+
+Run: `pytest -v`
+Expected: All tests PASS
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add services/*/src/*/main.py shared/src/shared/db.py
+git commit -m "refactor: specialize exception handling across all services"
+```
+
+---
+
+### Task 12: Fix Stock Selector (Bug Fix + Dedup + Session Reuse)
+
+**Files:**
+- Modify: `services/strategy-engine/src/strategy_engine/stock_selector.py`
+- Modify: `services/strategy-engine/tests/test_stock_selector.py` (if exists, otherwise create)
+
+- [ ] **Step 1: Fix the critical bug on line 217**
+
+In `stock_selector.py` line 217, replace:
+```python
+self._session = anthropic_model
+```
+with:
+```python
+self._model = anthropic_model
+```
+
+- [ ] **Step 2: Extract common JSON parsing function**
+
+Replace the duplicate parsing logic. Add at module level (replacing `_parse_llm_selections`):
+
+```python
+def _extract_json_array(text: str) -> list[dict] | None:
+ """Extract a JSON array from text that may contain markdown code blocks."""
+ code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL)
+ if code_block:
+ raw = code_block.group(1)
+ else:
+ array_match = re.search(r"\[.*\]", text, re.DOTALL)
+ if array_match:
+ raw = array_match.group(0)
+ else:
+ raw = text.strip()
+
+ try:
+ data = json.loads(raw)
+ if isinstance(data, list):
+ return [item for item in data if isinstance(item, dict)]
+ return None
+ except (json.JSONDecodeError, TypeError):
+ return None
+
+
+def _parse_llm_selections(text: str) -> list[SelectedStock]:
+ """Parse LLM response into SelectedStock list."""
+ items = _extract_json_array(text)
+ if items is None:
+ return []
+ selections = []
+ for item in items:
+ try:
+ selections.append(
+ SelectedStock(
+ symbol=item["symbol"],
+ side=OrderSide(item["side"]),
+ conviction=float(item["conviction"]),
+ reason=item.get("reason", ""),
+ key_news=item.get("key_news", []),
+ )
+ )
+ except (KeyError, ValueError) as e:
+ logger.warning("Skipping invalid selection item: %s", e)
+ return selections
+```
+
+Update `LLMCandidateSource._parse_candidates()` to use `_extract_json_array`:
+
+```python
+ def _parse_candidates(self, text: str) -> list[Candidate]:
+ items = _extract_json_array(text)
+ if items is None:
+ return []
+ candidates = []
+ for item in items:
+ try:
+ direction_str = item.get("direction", "BUY")
+ direction = OrderSide(direction_str)
+ except ValueError:
+ direction = None
+ candidates.append(
+ Candidate(
+ symbol=item["symbol"],
+ source="llm",
+ direction=direction,
+ score=float(item.get("score", 0.5)),
+ reason=item.get("reason", ""),
+ )
+ )
+ return candidates
+```
+
+- [ ] **Step 3: Add session reuse to StockSelector**
+
+Add `_http_session` to `StockSelector.__init__()`:
+
+```python
+self._http_session: aiohttp.ClientSession | None = None
+```
+
+Add helper method:
+
+```python
+async def _ensure_session(self) -> aiohttp.ClientSession:
+ if self._http_session is None or self._http_session.closed:
+ self._http_session = aiohttp.ClientSession()
+ return self._http_session
+
+async def close(self) -> None:
+ if self._http_session and not self._http_session.closed:
+ await self._http_session.close()
+```
+
+Replace `async with aiohttp.ClientSession() as session:` in both `LLMCandidateSource.get_candidates()` and `StockSelector._llm_final_select()` with session reuse. For `LLMCandidateSource`, accept an optional session parameter. For `StockSelector._llm_final_select()`, use `self._ensure_session()`.
+
+- [ ] **Step 4: Run tests**
+
+Run: `pytest services/strategy-engine/tests/ -v`
+Expected: All tests PASS
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add services/strategy-engine/src/strategy_engine/stock_selector.py
+git commit -m "fix: fix model attr bug, deduplicate LLM parsing, reuse aiohttp sessions"
+```
+
+---
+
+## Phase 4: API Security
+
+### Task 13: Bearer Token Authentication
+
+**Files:**
+- Create: `services/api/src/trading_api/dependencies/__init__.py`
+- Create: `services/api/src/trading_api/dependencies/auth.py`
+- Modify: `services/api/src/trading_api/main.py`
+
+- [ ] **Step 1: Create auth dependency**
+
+Create `services/api/src/trading_api/dependencies/__init__.py` (empty file).
+
+Create `services/api/src/trading_api/dependencies/auth.py`:
+
+```python
+"""Bearer token authentication dependency."""
+
+import logging
+
+from fastapi import Depends, HTTPException, status
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+
+from shared.config import Settings
+
+logger = logging.getLogger(__name__)
+
+_security = HTTPBearer(auto_error=False)
+_settings = Settings()
+
+
+async def verify_token(
+ credentials: HTTPAuthorizationCredentials | None = Depends(_security),
+) -> None:
+ """Verify Bearer token. Skip auth if API_AUTH_TOKEN is not configured."""
+ token = _settings.api_auth_token.get_secret_value()
+ if not token:
+ return # Auth disabled in dev mode
+
+ if credentials is None or credentials.credentials != token:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid or missing authentication token",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+```
+
+- [ ] **Step 2: Apply auth to all API routes**
+
+In `services/api/src/trading_api/main.py`, add auth dependency to routers:
+
+```python
+from trading_api.dependencies.auth import verify_token
+from fastapi import Depends
+
+app.include_router(portfolio_router, prefix="/api/v1/portfolio", dependencies=[Depends(verify_token)])
+app.include_router(orders_router, prefix="/api/v1/orders", dependencies=[Depends(verify_token)])
+app.include_router(strategies_router, prefix="/api/v1/strategies", dependencies=[Depends(verify_token)])
+```
+
+Log a warning on startup if token is empty:
+
+```python
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ cfg = Settings()
+ if not cfg.api_auth_token.get_secret_value():
+ logger.warning("API_AUTH_TOKEN not set; API authentication is disabled")
+ # ... rest of lifespan
+```
+
+- [ ] **Step 3: Write tests for auth**
+
+Add to `services/api/tests/test_auth.py`:
+
+```python
+"""Tests for API authentication."""
+
+from unittest.mock import patch
+
+import pytest
+from httpx import ASGITransport, AsyncClient
+
+from trading_api.main import app
+
+
+class TestAuth:
+ @patch("trading_api.dependencies.auth._settings")
+ async def test_rejects_missing_token_when_configured(self, mock_settings):
+ from pydantic import SecretStr
+ mock_settings.api_auth_token = SecretStr("test-token")
+ async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
+ resp = await ac.get("/api/v1/portfolio/positions")
+ assert resp.status_code == 401
+
+ @patch("trading_api.dependencies.auth._settings")
+ async def test_accepts_valid_token(self, mock_settings):
+ from pydantic import SecretStr
+ mock_settings.api_auth_token = SecretStr("test-token")
+ async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
+ resp = await ac.get(
+ "/api/v1/portfolio/positions",
+ headers={"Authorization": "Bearer test-token"},
+ )
+ # May fail with 500 if DB not available, but should NOT be 401
+ assert resp.status_code != 401
+```
+
+- [ ] **Step 4: Run tests**
+
+Run: `pytest services/api/tests/test_auth.py -v`
+Expected: PASS
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add services/api/src/trading_api/dependencies/ services/api/src/trading_api/main.py services/api/tests/test_auth.py
+git commit -m "feat: add Bearer token authentication to API endpoints"
+```
+
+---
+
+### Task 14: CORS & Rate Limiting
+
+**Files:**
+- Modify: `services/api/src/trading_api/main.py`
+- Modify: `services/api/pyproject.toml`
+
+- [ ] **Step 1: Add slowapi dependency**
+
+Already done in Task 5 (`services/api/pyproject.toml` has `slowapi>=0.1.9,<1`).
+
+- [ ] **Step 2: Add CORS and rate limiting to main.py**
+
+In `services/api/src/trading_api/main.py`:
+
+```python
+from fastapi.middleware.cors import CORSMiddleware
+from slowapi import Limiter, _rate_limit_exceeded_handler
+from slowapi.util import get_remote_address
+from slowapi.errors import RateLimitExceeded
+
+from shared.config import Settings
+
+cfg = Settings()
+
+limiter = Limiter(key_func=get_remote_address)
+app = FastAPI(title="Trading Platform API")
+app.state.limiter = limiter
+app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=cfg.cors_origins.split(","),
+ allow_methods=["GET", "POST"],
+ allow_headers=["Authorization", "Content-Type"],
+)
+```
+
+- [ ] **Step 3: Add rate limits to order endpoints**
+
+In `services/api/src/trading_api/routers/orders.py`:
+
+```python
+from slowapi import Limiter
+from slowapi.util import get_remote_address
+
+limiter = Limiter(key_func=get_remote_address)
+
+@router.get("/")
+@limiter.limit("60/minute")
+async def get_orders(request: Request, limit: int = 50):
+ ...
+```
+
+- [ ] **Step 4: Run tests**
+
+Run: `pytest services/api/tests/ -v`
+Expected: PASS
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add services/api/src/trading_api/main.py services/api/src/trading_api/routers/
+git commit -m "feat: add CORS middleware and rate limiting to API"
+```
+
+---
+
+### Task 15: API Input Validation & Response Models
+
+**Files:**
+- Modify: `services/api/src/trading_api/routers/portfolio.py`
+- Modify: `services/api/src/trading_api/routers/orders.py`
+
+- [ ] **Step 1: Add Query validation to portfolio.py**
+
+```python
+from fastapi import Query
+
+@router.get("/snapshots")
+async def get_snapshots(request: Request, days: int = Query(30, ge=1, le=365)):
+ ...
+```
+
+- [ ] **Step 2: Add Query validation to orders.py**
+
+```python
+from fastapi import Query
+
+@router.get("/")
+async def get_orders(request: Request, limit: int = Query(50, ge=1, le=1000)):
+ ...
+
+@router.get("/signals")
+async def get_signals(request: Request, limit: int = Query(50, ge=1, le=1000)):
+ ...
+```
+
+- [ ] **Step 3: Run tests**
+
+Run: `pytest services/api/tests/ -v`
+Expected: PASS
+
+- [ ] **Step 4: Commit**
+
+```bash
+git add services/api/src/trading_api/routers/
+git commit -m "feat: add input validation with Query bounds to API endpoints"
+```
+
+---
+
+## Phase 5: Operational Maturity
+
+### Task 16: Enhanced Ruff Configuration
+
+**Files:**
+- Modify: `pyproject.toml:12-14`
+
+- [ ] **Step 1: Update ruff config in pyproject.toml**
+
+Replace the ruff section in root `pyproject.toml`:
+
+```toml
+[tool.ruff]
+target-version = "py312"
+line-length = 100
+
+[tool.ruff.lint]
+select = ["E", "W", "F", "I", "B", "UP", "ASYNC", "PERF", "C4", "RUF"]
+ignore = ["E501"]
+
+[tool.ruff.lint.per-file-ignores]
+"tests/*" = ["F841"]
+"*/tests/*" = ["F841"]
+
+[tool.ruff.lint.isort]
+known-first-party = ["shared"]
+```
+
+- [ ] **Step 2: Auto-fix existing violations**
+
+Run: `ruff check --fix . && ruff format .`
+Expected: Fixes applied
+
+- [ ] **Step 3: Verify no remaining errors**
+
+Run: `ruff check . && ruff format --check .`
+Expected: No errors
+
+- [ ] **Step 4: Run tests to verify no regressions**
+
+Run: `pytest -v`
+Expected: All tests PASS
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add pyproject.toml
+git commit -m "chore: enhance ruff lint rules with ASYNC, bugbear, isort, pyupgrade"
+```
+
+Then commit auto-fixes separately:
+
+```bash
+git add -A
+git commit -m "style: auto-fix lint violations from enhanced ruff rules"
+```
+
+---
+
+### Task 17: GitHub Actions CI Pipeline
+
+**Files:**
+- Create: `.github/workflows/ci.yml`
+
+- [ ] **Step 1: Create CI workflow**
+
+Create `.github/workflows/ci.yml`:
+
+```yaml
+name: CI
+
+on:
+ push:
+ branches: [master]
+ pull_request:
+ branches: [master]
+
+jobs:
+ lint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.12"
+ - run: pip install ruff
+ - run: ruff check .
+ - run: ruff format --check .
+
+ test:
+ runs-on: ubuntu-latest
+ services:
+ redis:
+ image: redis:7-alpine
+ ports: [6379:6379]
+ options: >-
+ --health-cmd "redis-cli ping"
+ --health-interval 5s
+ --health-timeout 3s
+ --health-retries 5
+ postgres:
+ image: postgres:16-alpine
+ env:
+ POSTGRES_USER: trading
+ POSTGRES_PASSWORD: trading
+ POSTGRES_DB: trading
+ ports: [5432:5432]
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 5s
+ --health-timeout 3s
+ --health-retries 5
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.12"
+ - run: |
+ pip install -e shared/[dev]
+ pip install -e services/strategy-engine/[dev]
+ pip install -e services/data-collector/[dev]
+ pip install -e services/order-executor/[dev]
+ pip install -e services/portfolio-manager/[dev]
+ pip install -e services/news-collector/[dev]
+ pip install -e services/api/[dev]
+ pip install -e services/backtester/[dev]
+ pip install pytest-cov
+ - run: pytest -v --cov=shared/src --cov=services --cov-report=xml --cov-report=term-missing
+ env:
+ DATABASE_URL: postgresql+asyncpg://trading:trading@localhost:5432/trading
+ REDIS_URL: redis://localhost:6379
+ - uses: actions/upload-artifact@v4
+ with:
+ name: coverage-report
+ path: coverage.xml
+
+ docker:
+ runs-on: ubuntu-latest
+ needs: [lint, test]
+ if: github.ref == 'refs/heads/master'
+ steps:
+ - uses: actions/checkout@v4
+ - run: docker compose build --quiet
+```
+
+- [ ] **Step 2: Commit**
+
+```bash
+mkdir -p .github/workflows
+git add .github/workflows/ci.yml
+git commit -m "feat: add GitHub Actions CI pipeline with lint, test, docker build"
+```
+
+---
+
+### Task 18: Prometheus Alerting Rules
+
+**Files:**
+- Create: `monitoring/prometheus/alert_rules.yml`
+- Modify: `monitoring/prometheus.yml`
+
+- [ ] **Step 1: Create alert rules**
+
+Create `monitoring/prometheus/alert_rules.yml`:
+
+```yaml
+groups:
+ - name: trading-platform
+ rules:
+ - alert: ServiceDown
+ expr: up == 0
+ for: 1m
+ labels:
+ severity: critical
+ annotations:
+ summary: "Service {{ $labels.job }} is down"
+ description: "{{ $labels.instance }} has been unreachable for 1 minute."
+
+ - alert: HighErrorRate
+ expr: rate(errors_total[5m]) > 10
+ for: 2m
+ labels:
+ severity: warning
+ annotations:
+ summary: "High error rate on {{ $labels.job }}"
+ description: "Error rate is {{ $value }} errors/sec over 5 minutes."
+
+ - alert: HighProcessingLatency
+ expr: histogram_quantile(0.95, rate(processing_seconds_bucket[5m])) > 5
+ for: 5m
+ labels:
+ severity: warning
+ annotations:
+ summary: "High p95 latency on {{ $labels.job }}"
+ description: "95th percentile processing time is {{ $value }}s."
+```
+
+- [ ] **Step 2: Reference alert rules in prometheus.yml**
+
+In `monitoring/prometheus.yml`, add after `global:`:
+
+```yaml
+rule_files:
+ - "/etc/prometheus/alert_rules.yml"
+```
+
+Update `docker-compose.yml` prometheus service to mount the file:
+
+```yaml
+ prometheus:
+ volumes:
+ - ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml
+ - ./monitoring/prometheus/alert_rules.yml:/etc/prometheus/alert_rules.yml
+```
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add monitoring/prometheus/alert_rules.yml monitoring/prometheus.yml docker-compose.yml
+git commit -m "feat: add Prometheus alerting rules for service health, errors, latency"
+```
+
+---
+
+### Task 19: Code Coverage Configuration
+
+**Files:**
+- Modify: `pyproject.toml`
+
+- [ ] **Step 1: Add pytest-cov config**
+
+Add to `pyproject.toml`:
+
+```toml
+[tool.coverage.run]
+branch = true
+source = ["shared/src", "services"]
+omit = ["*/tests/*", "*/alembic/*"]
+
+[tool.coverage.report]
+fail_under = 60
+show_missing = true
+exclude_lines = [
+ "pragma: no cover",
+ "if __name__",
+ "if TYPE_CHECKING",
+]
+```
+
+Update pytest addopts:
+```toml
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+testpaths = ["shared/tests", "services", "cli/tests", "tests"]
+addopts = "--import-mode=importlib"
+```
+
+Note: `--cov` flags are passed explicitly in CI, not in addopts (to avoid slowing local dev).
+
+- [ ] **Step 2: Verify coverage works**
+
+Run: `pip install pytest-cov && pytest --cov=shared/src --cov-report=term-missing`
+Expected: Coverage report printed, no errors
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add pyproject.toml
+git commit -m "chore: add pytest-cov configuration with 60% minimum coverage threshold"
+```
+
+---
+
+## Summary
+
+| Phase | Tasks | Estimated Commits |
+|-------|-------|-------------------|
+| 1: Shared Library | Tasks 1-5 | 5 commits |
+| 2: Infrastructure | Tasks 6-9 | 4 commits |
+| 3: Service Fixes | Tasks 10-12 | 3 commits |
+| 4: API Security | Tasks 13-15 | 3 commits |
+| 5: Operations | Tasks 16-19 | 5 commits |
+| **Total** | **19 tasks** | **~20 commits** |
diff --git a/docs/superpowers/specs/2026-04-02-platform-upgrade-design.md b/docs/superpowers/specs/2026-04-02-platform-upgrade-design.md
new file mode 100644
index 0000000..9c84e10
--- /dev/null
+++ b/docs/superpowers/specs/2026-04-02-platform-upgrade-design.md
@@ -0,0 +1,257 @@
+# Platform Upgrade Design Spec
+
+**Date**: 2026-04-02
+**Approach**: Bottom-Up (shared library → infra → services → API security → operations)
+
+---
+
+## Phase 1: Shared Library Hardening
+
+### 1-1. Resilience Module (`shared/src/shared/resilience.py`)
+Currently empty. Implement:
+- **`retry_async()`** — tenacity-based exponential backoff + jitter decorator. Configurable max retries (default 3), base delay (1s), max delay (30s).
+- **`CircuitBreaker`** — Tracks consecutive failures. Opens after N failures (default 5), stays open for configurable cooldown (default 60s), transitions to half-open to test recovery.
+- **`timeout()`** — asyncio-based timeout wrapper. Raises `TimeoutError` after configurable duration.
+- All decorators composable: `@retry_async() @circuit_breaker() async def call_api():`
+
+### 1-2. DB Connection Pooling (`shared/src/shared/db.py`)
+Add to `create_async_engine()`:
+- `pool_size=20` (configurable via `DB_POOL_SIZE`)
+- `max_overflow=10` (configurable via `DB_MAX_OVERFLOW`)
+- `pool_pre_ping=True` (verify connections before use)
+- `pool_recycle=3600` (recycle stale connections)
+Add corresponding fields to `Settings`.
+
+### 1-3. Redis Resilience (`shared/src/shared/broker.py`)
+- Add to `redis.asyncio.from_url()`: `socket_keepalive=True`, `health_check_interval=30`, `retry_on_timeout=True`
+- Wrap `publish()`, `read_group()`, `ensure_group()` with `@retry_async()` from resilience module
+- Add `reconnect()` method for connection loss recovery
+
+### 1-4. Config Validation (`shared/src/shared/config.py`)
+- Add `field_validator` for business logic: `risk_max_position_size > 0`, `health_port` in 1024-65535, `log_level` in valid set
+- Change secret fields to `SecretStr`: `alpaca_api_key`, `alpaca_api_secret`, `database_url`, `redis_url`, `telegram_bot_token`, `anthropic_api_key`, `finnhub_api_key`
+- Update all consumers to call `.get_secret_value()` where needed
+
+### 1-5. Dependency Pinning
+All `pyproject.toml` files: add upper bounds.
+Examples:
+- `pydantic>=2.8,<3`
+- `redis>=5.0,<6`
+- `sqlalchemy[asyncio]>=2.0,<3`
+- `numpy>=1.26,<3`
+- `pandas>=2.1,<3`
+- `anthropic>=0.40,<1`
+Run `uv lock` to generate lock file.
+
+---
+
+## Phase 2: Infrastructure Hardening
+
+### 2-1. Docker Secrets & Environment
+- Remove hardcoded `POSTGRES_USER: trading` / `POSTGRES_PASSWORD: trading` from `docker-compose.yml`
+- Reference via `${POSTGRES_USER}` / `${POSTGRES_PASSWORD}` from `.env`
+- Add comments in `.env.example` marking secret vs config variables
+
+### 2-2. Dockerfile Optimization (all 7 services)
+Pattern for each Dockerfile:
+```dockerfile
+# Stage 1: builder
+FROM python:3.12-slim AS builder
+WORKDIR /app
+COPY shared/pyproject.toml shared/setup.cfg shared/
+COPY shared/src/ shared/src/
+RUN pip install --no-cache-dir ./shared
+COPY services/<name>/pyproject.toml services/<name>/
+COPY services/<name>/src/ services/<name>/src/
+RUN pip install --no-cache-dir ./services/<name>
+
+# Stage 2: runtime
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
+USER appuser
+CMD ["python", "-m", "<module>.main"]
+```
+
+Create root `.dockerignore`:
+```
+__pycache__
+*.pyc
+.git
+.venv
+.env
+tests/
+docs/
+*.md
+.ruff_cache
+```
+
+### 2-3. Database Index Migration (`003_add_missing_indexes.py`)
+New Alembic migration adding:
+- `idx_signals_symbol_created` on `signals(symbol, created_at)`
+- `idx_orders_symbol_status_created` on `orders(symbol, status, created_at)`
+- `idx_trades_order_id` on `trades(order_id)`
+- `idx_trades_symbol_traded` on `trades(symbol, traded_at)`
+- `idx_portfolio_snapshots_at` on `portfolio_snapshots(snapshot_at)`
+- `idx_symbol_scores_symbol` unique on `symbol_scores(symbol)`
+
+### 2-4. Docker Compose Resource Limits
+Add to each service:
+```yaml
+deploy:
+ resources:
+ limits:
+ memory: 512M
+ cpus: '1.0'
+```
+Strategy-engine and backtester get `memory: 1G` (pandas/numpy usage).
+
+Add explicit networks:
+```yaml
+networks:
+ internal:
+ driver: bridge
+ monitoring:
+ driver: bridge
+```
+
+---
+
+## Phase 3: Service-Level Improvements
+
+### 3-1. Graceful Shutdown (all services)
+Add to each service's `main()`:
+```python
+shutdown_event = asyncio.Event()
+
+def _signal_handler():
+ log.info("shutdown_signal_received")
+ shutdown_event.set()
+
+loop = asyncio.get_event_loop()
+loop.add_signal_handler(signal.SIGTERM, _signal_handler)
+loop.add_signal_handler(signal.SIGINT, _signal_handler)
+```
+Main loops check `shutdown_event.is_set()` to exit gracefully.
+API service: add `--timeout-graceful-shutdown 30` to uvicorn CMD.
+
+### 3-2. Exception Specialization (all services)
+Replace broad `except Exception` with layered handling:
+- `ConnectionError`, `TimeoutError` → retry via resilience module
+- `ValueError`, `KeyError` → log warning, skip item, continue
+- `Exception` → top-level only, `exc_info=True` for full traceback + Telegram alert
+
+Target: reduce 63 broad catches to ~10 top-level safety nets.
+
+### 3-3. LLM Parsing Deduplication (`stock_selector.py`)
+Extract `_extract_json_array(text: str) -> list[dict] | None`:
+- Tries ```` ```json ``` ```` code block extraction
+- Falls back to `re.search(r"\[.*\]", text, re.DOTALL)`
+- Falls back to raw `json.loads(text.strip())`
+Replace 3 duplicate parsing blocks with single call.
+
+### 3-4. aiohttp Session Reuse (`stock_selector.py`)
+- Add `_session: aiohttp.ClientSession | None = None` to `StockSelector`
+- Lazy-init in `_ensure_session()`, close in `close()`
+- Replace all `async with aiohttp.ClientSession()` with `self._session`
+
+---
+
+## Phase 4: API Security
+
+### 4-1. Bearer Token Authentication
+- Add `api_auth_token: SecretStr = ""` to `Settings`
+- Create `dependencies/auth.py` with `verify_token()` dependency
+- Apply to all `/api/v1/*` routes via router-level `dependencies=[Depends(verify_token)]`
+- If token is empty string → skip auth (dev mode), log warning on startup
+
+### 4-2. CORS Configuration
+```python
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=settings.cors_origins.split(","), # default: "http://localhost:3000"
+ allow_methods=["GET", "POST"],
+ allow_headers=["Authorization", "Content-Type"],
+)
+```
+Add `cors_origins: str = "http://localhost:3000"` to Settings.
+
+### 4-3. Rate Limiting
+- Add `slowapi` dependency
+- Global default: 60 req/min per IP
+- Order-related endpoints: 10 req/min per IP
+- Return `429 Too Many Requests` with `Retry-After` header
+
+### 4-4. Input Validation
+- All `limit` params: `Query(default=50, ge=1, le=1000)`
+- All `days` params: `Query(default=30, ge=1, le=365)`
+- Add Pydantic `response_model` to all endpoints (enables auto OpenAPI docs)
+- Add `symbol` param validation: uppercase, 1-5 chars, alphanumeric
+
+---
+
+## Phase 5: Operational Maturity
+
+### 5-1. GitHub Actions CI/CD
+File: `.github/workflows/ci.yml`
+
+**PR trigger** (`pull_request`):
+1. Install deps (`uv sync`)
+2. Ruff lint + format check
+3. pytest with coverage (`--cov --cov-report=xml`)
+4. Upload coverage to PR comment
+
+**Main push** (`push: branches: [master]`):
+1. Same lint + test
+2. `docker compose build`
+3. (Future: push to registry)
+
+### 5-2. Ruff Rules Enhancement
+```toml
+[tool.ruff.lint]
+select = ["E", "W", "F", "I", "B", "UP", "ASYNC", "PERF", "C4", "RUF"]
+ignore = ["E501"]
+
+[tool.ruff.lint.per-file-ignores]
+"tests/*" = ["F841"]
+
+[tool.ruff.lint.isort]
+known-first-party = ["shared"]
+```
+Run `ruff check --fix .` and `ruff format .` to fix existing violations, commit separately.
+
+### 5-3. Prometheus Alerting
+File: `monitoring/prometheus/alert_rules.yml`
+Rules:
+- `ServiceDown`: `up == 0` for 1 min → critical
+- `HighErrorRate`: `rate(errors_total[5m]) > 10` → warning
+- `HighLatency`: `histogram_quantile(0.95, processing_seconds) > 5` → warning
+
+Add Alertmanager config with Telegram webhook (reuse existing bot token).
+Reference alert rules in `monitoring/prometheus.yml`.
+
+### 5-4. Code Coverage
+Add to root `pyproject.toml`:
+```toml
+[tool.pytest.ini_options]
+addopts = "--import-mode=importlib"  # --cov flags are passed explicitly in CI, not here, to keep local runs fast
+
+[tool.coverage.run]
+branch = true
+omit = ["tests/*", "*/alembic/*"]
+
+[tool.coverage.report]
+fail_under = 60
+```
+Add `pytest-cov` to dev dependencies.
+
+---
+
+## Out of Scope
+- Kubernetes/Helm charts (premature — Docker Compose sufficient for current scale)
+- External secrets manager (Vault, AWS SM — overkill for single-machine deployment)
+- OpenTelemetry distributed tracing (add when debugging cross-service issues)
+- API versioning beyond `/api/v1/` prefix
+- Data retention/partitioning (address when data volume becomes an issue)
diff --git a/monitoring/prometheus.yml b/monitoring/prometheus.yml
index b6dc853..e177c9c 100644
--- a/monitoring/prometheus.yml
+++ b/monitoring/prometheus.yml
@@ -1,5 +1,7 @@
global:
scrape_interval: 15s
+rule_files:
+ - "/etc/prometheus/alert_rules.yml"
scrape_configs:
- job_name: "trading-services"
authorization:
diff --git a/monitoring/prometheus/alert_rules.yml b/monitoring/prometheus/alert_rules.yml
new file mode 100644
index 0000000..aca2f1c
--- /dev/null
+++ b/monitoring/prometheus/alert_rules.yml
@@ -0,0 +1,29 @@
+groups:
+ - name: trading-platform
+ rules:
+ - alert: ServiceDown
+ expr: up == 0
+ for: 1m
+ labels:
+ severity: critical
+ annotations:
+ summary: "Service {{ $labels.job }} is down"
+ description: "{{ $labels.instance }} has been unreachable for 1 minute."
+
+ - alert: HighErrorRate
+ expr: rate(errors_total[5m]) > 10
+ for: 2m
+ labels:
+ severity: warning
+ annotations:
+ summary: "High error rate on {{ $labels.job }}"
+ description: "Error rate is {{ $value }} errors/sec over 5 minutes."
+
+ - alert: HighProcessingLatency
+ expr: histogram_quantile(0.95, rate(processing_seconds_bucket[5m])) > 5
+ for: 5m
+ labels:
+ severity: warning
+ annotations:
+ summary: "High p95 latency on {{ $labels.job }}"
+ description: "95th percentile processing time is {{ $value }}s."
diff --git a/pyproject.toml b/pyproject.toml
index 6938778..08150fb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,3 +12,28 @@ addopts = "--import-mode=importlib"
[tool.ruff]
target-version = "py312"
line-length = 100
+
+[tool.ruff.lint]
+select = ["E", "W", "F", "I", "B", "UP", "ASYNC", "PERF", "C4", "RUF"]
+ignore = ["E501", "RUF012", "B008", "ASYNC240"]
+
+[tool.ruff.lint.per-file-ignores]
+"tests/*" = ["F841"]
+"*/tests/*" = ["F841"]
+
+[tool.ruff.lint.isort]
+known-first-party = ["shared"]
+
+[tool.coverage.run]
+branch = true
+source = ["shared/src", "services"]
+omit = ["*/tests/*", "*/alembic/*"]
+
+[tool.coverage.report]
+fail_under = 60
+show_missing = true
+exclude_lines = [
+ "pragma: no cover",
+ "if __name__",
+ "if TYPE_CHECKING",
+]
diff --git a/scripts/backtest_moc.py b/scripts/backtest_moc.py
index 92b426b..9307668 100755
--- a/scripts/backtest_moc.py
+++ b/scripts/backtest_moc.py
@@ -4,11 +4,11 @@
Usage: python scripts/backtest_moc.py
"""
-import sys
import random
-from pathlib import Path
+import sys
+from datetime import UTC, datetime, timedelta
from decimal import Decimal
-from datetime import datetime, timedelta, timezone
+from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "services" / "strategy-engine" / "src"))
@@ -16,10 +16,11 @@ sys.path.insert(0, str(ROOT / "services" / "strategy-engine"))
sys.path.insert(0, str(ROOT / "services" / "backtester" / "src"))
sys.path.insert(0, str(ROOT / "shared" / "src"))
-from shared.models import Candle # noqa: E402
from backtester.engine import BacktestEngine # noqa: E402
from strategies.moc_strategy import MocStrategy # noqa: E402
+from shared.models import Candle # noqa: E402
+
def generate_stock_candles(
symbol: str = "AAPL",
@@ -38,7 +39,7 @@ def generate_stock_candles(
"""
candles = []
price = base_price
- start_date = datetime(2025, 1, 2, tzinfo=timezone.utc) # Start on a Thursday
+ start_date = datetime(2025, 1, 2, tzinfo=UTC) # Start on a Thursday
trading_day = 0
current_date = start_date
@@ -136,27 +137,26 @@ def main():
]
# Parameter grid
- param_sets = []
- for rsi_min in [25, 30, 35]:
- for rsi_max in [55, 60, 65]:
- for sl in [1.5, 2.0, 3.0]:
- for ema in [10, 20]:
- param_sets.append(
- {
- "quantity_pct": 0.2,
- "stop_loss_pct": sl,
- "rsi_min": rsi_min,
- "rsi_max": rsi_max,
- "ema_period": ema,
- "volume_avg_period": 20,
- "min_volume_ratio": 0.8,
- "buy_start_utc": 19,
- "buy_end_utc": 21,
- "sell_start_utc": 14,
- "sell_end_utc": 15,
- "max_positions": 5,
- }
- )
+ param_sets = [
+ {
+ "quantity_pct": 0.2,
+ "stop_loss_pct": sl,
+ "rsi_min": rsi_min,
+ "rsi_max": rsi_max,
+ "ema_period": ema,
+ "volume_avg_period": 20,
+ "min_volume_ratio": 0.8,
+ "buy_start_utc": 19,
+ "buy_end_utc": 21,
+ "sell_start_utc": 14,
+ "sell_end_utc": 15,
+ "max_positions": 5,
+ }
+ for rsi_min in [25, 30, 35]
+ for rsi_max in [55, 60, 65]
+ for sl in [1.5, 2.0, 3.0]
+ for ema in [10, 20]
+ ]
print(f"\nParameter combinations: {len(param_sets)}")
print(f"Stocks: {[s[0] for s in stocks]}")
@@ -234,7 +234,7 @@ def main():
print("\n" + "=" * 60)
print("WORST 3 PARAMETER SETS")
print("=" * 60)
- for _rank, (params, profit, trades, sharpe, _) in enumerate(param_results[-3:], 1):
+ for _rank, (params, profit, _trades, sharpe, _) in enumerate(param_results[-3:], 1):
print(
f" RSI({params['rsi_min']}-{params['rsi_max']}),"
f" SL={params['stop_loss_pct']}%, EMA={params['ema_period']}"
diff --git a/scripts/stock_screener.py b/scripts/stock_screener.py
index 7a5c0ba..7552aa3 100755
--- a/scripts/stock_screener.py
+++ b/scripts/stock_screener.py
@@ -16,7 +16,7 @@ import asyncio
import json
import os
import sys
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
@@ -195,7 +195,7 @@ async def main_async(top_n: int = 5, universe: list[str] | None = None):
print("=" * 60)
print("Daily Stock Screener — MOC Strategy")
- print(f"Date: {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M UTC')}")
+ print(f"Date: {datetime.now(UTC).strftime('%Y-%m-%d %H:%M UTC')}")
print(f"Universe: {len(symbols)} stocks")
print("=" * 60)
diff --git a/services/api/Dockerfile b/services/api/Dockerfile
index b942075..93d2b75 100644
--- a/services/api/Dockerfile
+++ b/services/api/Dockerfile
@@ -1,11 +1,18 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/api/ services/api/
RUN pip install --no-cache-dir ./services/api
-COPY services/strategy-engine/strategies/ /app/strategies/
COPY services/strategy-engine/ services/strategy-engine/
RUN pip install --no-cache-dir ./services/strategy-engine
-ENV PYTHONPATH=/app
-CMD ["uvicorn", "trading_api.main:app", "--host", "0.0.0.0", "--port", "8000"]
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
+COPY services/strategy-engine/strategies/ /app/strategies/
+ENV PYTHONPATH=/app STRATEGIES_DIR=/app/strategies
+USER appuser
+CMD ["uvicorn", "trading_api.main:app", "--host", "0.0.0.0", "--port", "8000", "--timeout-graceful-shutdown", "30"]
diff --git a/services/api/pyproject.toml b/services/api/pyproject.toml
index fd2598d..95099d2 100644
--- a/services/api/pyproject.toml
+++ b/services/api/pyproject.toml
@@ -3,11 +3,7 @@ name = "trading-api"
version = "0.1.0"
description = "REST API for the trading platform"
requires-python = ">=3.12"
-dependencies = [
- "fastapi>=0.110",
- "uvicorn>=0.27",
- "trading-shared",
-]
+dependencies = ["fastapi>=0.110,<1", "uvicorn>=0.27,<1", "slowapi>=0.1.9,<1", "trading-shared"]
[project.optional-dependencies]
dev = ["pytest>=8.0", "pytest-asyncio>=0.23", "httpx>=0.27"]
diff --git a/services/api/src/trading_api/dependencies/__init__.py b/services/api/src/trading_api/dependencies/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/services/api/src/trading_api/dependencies/__init__.py
diff --git a/services/api/src/trading_api/dependencies/auth.py b/services/api/src/trading_api/dependencies/auth.py
new file mode 100644
index 0000000..a5e76c1
--- /dev/null
+++ b/services/api/src/trading_api/dependencies/auth.py
@@ -0,0 +1,29 @@
+"""Bearer token authentication dependency."""
+
+import logging
+
+from fastapi import Depends, HTTPException, status
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+
+from shared.config import Settings
+
+logger = logging.getLogger(__name__)
+
+_security = HTTPBearer(auto_error=False)
+_settings = Settings()
+
+
+async def verify_token(
+ credentials: HTTPAuthorizationCredentials | None = Depends(_security),
+) -> None:
+ """Verify Bearer token. Skip auth if API_AUTH_TOKEN is not configured."""
+ token = _settings.api_auth_token.get_secret_value()
+ if not token:
+ return # Auth disabled in dev mode
+
+ if credentials is None or credentials.credentials != token:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid or missing authentication token",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
diff --git a/services/api/src/trading_api/main.py b/services/api/src/trading_api/main.py
index 39f7b43..05c6d2f 100644
--- a/services/api/src/trading_api/main.py
+++ b/services/api/src/trading_api/main.py
@@ -1,33 +1,71 @@
"""Trading Platform REST API."""
+import logging
from contextlib import asynccontextmanager
-from fastapi import FastAPI
+from fastapi import Depends, FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from slowapi import Limiter, _rate_limit_exceeded_handler
+from slowapi.errors import RateLimitExceeded
+from slowapi.util import get_remote_address
from shared.config import Settings
from shared.db import Database
+from trading_api.dependencies.auth import verify_token
+from trading_api.routers import orders, portfolio, strategies
-from trading_api.routers import portfolio, orders, strategies
+logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
settings = Settings()
- app.state.db = Database(settings.database_url)
+ if not settings.api_auth_token.get_secret_value():
+ logger.warning("API_AUTH_TOKEN not set — authentication is disabled")
+ app.state.db = Database(settings.database_url.get_secret_value())
await app.state.db.connect()
yield
await app.state.db.close()
+cfg = Settings()
+
+limiter = Limiter(key_func=get_remote_address)
+
app = FastAPI(
title="Trading Platform API",
version="0.1.0",
lifespan=lifespan,
)
-app.include_router(portfolio.router, prefix="/api/v1/portfolio", tags=["portfolio"])
-app.include_router(orders.router, prefix="/api/v1/orders", tags=["orders"])
-app.include_router(strategies.router, prefix="/api/v1/strategies", tags=["strategies"])
+app.state.limiter = limiter
+app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=cfg.cors_origins.split(","),
+ allow_methods=["GET", "POST"],
+ allow_headers=["Authorization", "Content-Type"],
+)
+
+app.include_router(
+ portfolio.router,
+ prefix="/api/v1/portfolio",
+ tags=["portfolio"],
+ dependencies=[Depends(verify_token)],
+)
+app.include_router(
+ orders.router,
+ prefix="/api/v1/orders",
+ tags=["orders"],
+ dependencies=[Depends(verify_token)],
+)
+app.include_router(
+ strategies.router,
+ prefix="/api/v1/strategies",
+ tags=["strategies"],
+ dependencies=[Depends(verify_token)],
+)
@app.get("/health")
diff --git a/services/api/src/trading_api/routers/orders.py b/services/api/src/trading_api/routers/orders.py
index c69dc10..b664e2a 100644
--- a/services/api/src/trading_api/routers/orders.py
+++ b/services/api/src/trading_api/routers/orders.py
@@ -2,17 +2,23 @@
import logging
-from fastapi import APIRouter, HTTPException, Request
-from shared.sa_models import OrderRow, SignalRow
+from fastapi import APIRouter, HTTPException, Query, Request
+from slowapi import Limiter
+from slowapi.util import get_remote_address
from sqlalchemy import select
+from sqlalchemy.exc import OperationalError
+
+from shared.sa_models import OrderRow, SignalRow
logger = logging.getLogger(__name__)
router = APIRouter()
+limiter = Limiter(key_func=get_remote_address)
@router.get("/")
-async def get_orders(request: Request, limit: int = 50):
+@limiter.limit("60/minute")
+async def get_orders(request: Request, limit: int = Query(50, ge=1, le=1000)):
"""Get recent orders."""
try:
db = request.app.state.db
@@ -35,13 +41,17 @@ async def get_orders(request: Request, limit: int = 50):
}
for r in rows
]
+ except OperationalError as exc:
+ logger.error("Database error fetching orders: %s", exc)
+ raise HTTPException(status_code=503, detail="Database unavailable") from exc
except Exception as exc:
- logger.error("Failed to get orders: %s", exc)
- raise HTTPException(status_code=500, detail="Failed to retrieve orders")
+ logger.error("Failed to get orders: %s", exc, exc_info=True)
+ raise HTTPException(status_code=500, detail="Failed to retrieve orders") from exc
@router.get("/signals")
-async def get_signals(request: Request, limit: int = 50):
+@limiter.limit("60/minute")
+async def get_signals(request: Request, limit: int = Query(50, ge=1, le=1000)):
"""Get recent signals."""
try:
db = request.app.state.db
@@ -62,6 +72,9 @@ async def get_signals(request: Request, limit: int = 50):
}
for r in rows
]
+ except OperationalError as exc:
+ logger.error("Database error fetching signals: %s", exc)
+ raise HTTPException(status_code=503, detail="Database unavailable") from exc
except Exception as exc:
- logger.error("Failed to get signals: %s", exc)
- raise HTTPException(status_code=500, detail="Failed to retrieve signals")
+ logger.error("Failed to get signals: %s", exc, exc_info=True)
+ raise HTTPException(status_code=500, detail="Failed to retrieve signals") from exc
diff --git a/services/api/src/trading_api/routers/portfolio.py b/services/api/src/trading_api/routers/portfolio.py
index d76d85d..56bee7c 100644
--- a/services/api/src/trading_api/routers/portfolio.py
+++ b/services/api/src/trading_api/routers/portfolio.py
@@ -2,9 +2,11 @@
import logging
-from fastapi import APIRouter, HTTPException, Request
-from shared.sa_models import PositionRow
+from fastapi import APIRouter, HTTPException, Query, Request
from sqlalchemy import select
+from sqlalchemy.exc import OperationalError
+
+from shared.sa_models import PositionRow
logger = logging.getLogger(__name__)
@@ -29,13 +31,16 @@ async def get_positions(request: Request):
}
for r in rows
]
+ except OperationalError as exc:
+ logger.error("Database error fetching positions: %s", exc)
+ raise HTTPException(status_code=503, detail="Database unavailable") from exc
except Exception as exc:
- logger.error("Failed to get positions: %s", exc)
- raise HTTPException(status_code=500, detail="Failed to retrieve positions")
+ logger.error("Failed to get positions: %s", exc, exc_info=True)
+ raise HTTPException(status_code=500, detail="Failed to retrieve positions") from exc
@router.get("/snapshots")
-async def get_snapshots(request: Request, days: int = 30):
+async def get_snapshots(request: Request, days: int = Query(30, ge=1, le=365)):
"""Get portfolio snapshots for the last N days."""
try:
db = request.app.state.db
@@ -49,6 +54,9 @@ async def get_snapshots(request: Request, days: int = 30):
}
for s in snapshots
]
+ except OperationalError as exc:
+ logger.error("Database error fetching snapshots: %s", exc)
+ raise HTTPException(status_code=503, detail="Database unavailable") from exc
except Exception as exc:
- logger.error("Failed to get snapshots: %s", exc)
- raise HTTPException(status_code=500, detail="Failed to retrieve snapshots")
+ logger.error("Failed to get snapshots: %s", exc, exc_info=True)
+ raise HTTPException(status_code=500, detail="Failed to retrieve snapshots") from exc
diff --git a/services/api/src/trading_api/routers/strategies.py b/services/api/src/trading_api/routers/strategies.py
index 7ddd54e..157094c 100644
--- a/services/api/src/trading_api/routers/strategies.py
+++ b/services/api/src/trading_api/routers/strategies.py
@@ -42,6 +42,9 @@ async def list_strategies():
}
for s in strategies
]
+ except (ImportError, FileNotFoundError) as exc:
+ logger.error("Strategy loading error: %s", exc)
+ raise HTTPException(status_code=503, detail="Strategy engine unavailable") from exc
except Exception as exc:
- logger.error("Failed to list strategies: %s", exc)
- raise HTTPException(status_code=500, detail="Failed to list strategies")
+ logger.error("Failed to list strategies: %s", exc, exc_info=True)
+ raise HTTPException(status_code=500, detail="Failed to list strategies") from exc
diff --git a/services/api/tests/test_api.py b/services/api/tests/test_api.py
index 669143b..f3b0a47 100644
--- a/services/api/tests/test_api.py
+++ b/services/api/tests/test_api.py
@@ -1,6 +1,7 @@
"""Tests for the REST API."""
from unittest.mock import AsyncMock, patch
+
from fastapi.testclient import TestClient
diff --git a/services/api/tests/test_orders_router.py b/services/api/tests/test_orders_router.py
index 0658619..52252c5 100644
--- a/services/api/tests/test_orders_router.py
+++ b/services/api/tests/test_orders_router.py
@@ -1,10 +1,10 @@
"""Tests for orders API router."""
-import pytest
from unittest.mock import AsyncMock, MagicMock
-from fastapi.testclient import TestClient
-from fastapi import FastAPI
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
from trading_api.routers.orders import router
diff --git a/services/api/tests/test_portfolio_router.py b/services/api/tests/test_portfolio_router.py
index 3bd1b2c..8cd8ff8 100644
--- a/services/api/tests/test_portfolio_router.py
+++ b/services/api/tests/test_portfolio_router.py
@@ -1,11 +1,11 @@
"""Tests for portfolio API router."""
-import pytest
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock
-from fastapi.testclient import TestClient
-from fastapi import FastAPI
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
from trading_api.routers.portfolio import router
diff --git a/services/backtester/Dockerfile b/services/backtester/Dockerfile
index 9a4f439..1108e42 100644
--- a/services/backtester/Dockerfile
+++ b/services/backtester/Dockerfile
@@ -1,10 +1,17 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/backtester/ services/backtester/
RUN pip install --no-cache-dir ./services/backtester
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
COPY services/strategy-engine/strategies/ /app/strategies/
ENV STRATEGIES_DIR=/app/strategies
ENV PYTHONPATH=/app
+USER appuser
CMD ["python", "-m", "backtester.main"]
diff --git a/services/backtester/pyproject.toml b/services/backtester/pyproject.toml
index 2601d04..034bcf6 100644
--- a/services/backtester/pyproject.toml
+++ b/services/backtester/pyproject.toml
@@ -3,7 +3,7 @@ name = "backtester"
version = "0.1.0"
description = "Strategy backtesting engine"
requires-python = ">=3.12"
-dependencies = ["pandas>=2.0", "numpy>=1.20", "rich>=13.0", "trading-shared"]
+dependencies = ["pandas>=2.1,<3", "numpy>=1.26,<3", "rich>=13.0,<14", "trading-shared"]
[project.optional-dependencies]
dev = ["pytest>=8.0", "pytest-asyncio>=0.23"]
diff --git a/services/backtester/src/backtester/engine.py b/services/backtester/src/backtester/engine.py
index b03715d..fcf48f1 100644
--- a/services/backtester/src/backtester/engine.py
+++ b/services/backtester/src/backtester/engine.py
@@ -6,10 +6,9 @@ from dataclasses import dataclass, field
from decimal import Decimal
from typing import Protocol
-from shared.models import Candle, Signal
-
from backtester.metrics import DetailedMetrics, TradeRecord, compute_detailed_metrics
from backtester.simulator import OrderSimulator, SimulatedTrade
+from shared.models import Candle, Signal
class StrategyProtocol(Protocol):
@@ -101,7 +100,7 @@ class BacktestEngine:
final_balance = simulator.balance
if candles:
last_price = candles[-1].close
- for symbol, qty in simulator.positions.items():
+ for qty in simulator.positions.values():
if qty > Decimal("0"):
final_balance += qty * last_price
elif qty < Decimal("0"):
diff --git a/services/backtester/src/backtester/main.py b/services/backtester/src/backtester/main.py
index a4cea76..dbde00b 100644
--- a/services/backtester/src/backtester/main.py
+++ b/services/backtester/src/backtester/main.py
@@ -17,11 +17,11 @@ _STRATEGIES_DIR = Path(
if _STRATEGIES_DIR.parent not in [Path(p) for p in sys.path]:
sys.path.insert(0, str(_STRATEGIES_DIR.parent))
-from shared.db import Database # noqa: E402
-from shared.models import Candle # noqa: E402
from backtester.config import BacktestConfig # noqa: E402
from backtester.engine import BacktestEngine # noqa: E402
from backtester.reporter import format_report # noqa: E402
+from shared.db import Database # noqa: E402
+from shared.models import Candle # noqa: E402
async def run_backtest() -> str:
@@ -45,7 +45,7 @@ async def run_backtest() -> str:
except Exception as exc:
raise RuntimeError(f"Failed to load strategy '{config.strategy_name}': {exc}") from exc
- db = Database(config.database_url)
+ db = Database(config.database_url.get_secret_value())
await db.connect()
try:
rows = await db.get_candles(config.symbol, config.timeframe, config.candle_limit)
diff --git a/services/backtester/src/backtester/metrics.py b/services/backtester/src/backtester/metrics.py
index 239cb6f..c7b032b 100644
--- a/services/backtester/src/backtester/metrics.py
+++ b/services/backtester/src/backtester/metrics.py
@@ -266,7 +266,7 @@ def compute_detailed_metrics(
largest_win=largest_win,
largest_loss=largest_loss,
avg_holding_period=avg_holding,
- trade_pairs=[p for p in pairs],
+ trade_pairs=list(pairs),
risk_free_rate=risk_free_rate,
recovery_factor=recovery_factor,
max_consecutive_losses=max_consec_losses,
diff --git a/services/backtester/src/backtester/simulator.py b/services/backtester/src/backtester/simulator.py
index 64c88dd..6bce18b 100644
--- a/services/backtester/src/backtester/simulator.py
+++ b/services/backtester/src/backtester/simulator.py
@@ -1,9 +1,8 @@
"""Simulated order executor for backtesting."""
from dataclasses import dataclass, field
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
-from typing import Optional
from shared.models import OrderSide, Signal
@@ -16,7 +15,7 @@ class SimulatedTrade:
quantity: Decimal
balance_after: Decimal
fee: Decimal = Decimal("0")
- timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+ timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
@dataclass
@@ -27,8 +26,8 @@ class OpenPosition:
side: OrderSide # BUY = long, SELL = short
entry_price: Decimal
quantity: Decimal
- stop_loss: Optional[Decimal] = None
- take_profit: Optional[Decimal] = None
+ stop_loss: Decimal | None = None
+ take_profit: Decimal | None = None
class OrderSimulator:
@@ -70,7 +69,7 @@ class OrderSimulator:
remaining: list[OpenPosition] = []
for pos in self.open_positions:
triggered = False
- exit_price: Optional[Decimal] = None
+ exit_price: Decimal | None = None
if pos.side == OrderSide.BUY: # Long position
if pos.stop_loss is not None and candle_low <= pos.stop_loss:
@@ -125,12 +124,12 @@ class OrderSimulator:
def execute(
self,
signal: Signal,
- timestamp: Optional[datetime] = None,
- stop_loss: Optional[Decimal] = None,
- take_profit: Optional[Decimal] = None,
+ timestamp: datetime | None = None,
+ stop_loss: Decimal | None = None,
+ take_profit: Decimal | None = None,
) -> bool:
"""Execute a signal with slippage and fees. Returns True if accepted."""
- ts = timestamp or datetime.now(timezone.utc)
+ ts = timestamp or datetime.now(UTC)
exec_price = self._apply_slippage(signal.price, signal.side)
fee = self._calculate_fee(exec_price, signal.quantity)
diff --git a/services/backtester/src/backtester/walk_forward.py b/services/backtester/src/backtester/walk_forward.py
index c7b7fd8..720ad5e 100644
--- a/services/backtester/src/backtester/walk_forward.py
+++ b/services/backtester/src/backtester/walk_forward.py
@@ -1,11 +1,11 @@
"""Walk-forward analysis for strategy parameter optimization."""
+from collections.abc import Callable
from dataclasses import dataclass, field
from decimal import Decimal
-from typing import Callable
-from shared.models import Candle
from backtester.engine import BacktestEngine, BacktestResult, StrategyProtocol
+from shared.models import Candle
@dataclass
diff --git a/services/backtester/tests/test_engine.py b/services/backtester/tests/test_engine.py
index 4794e63..f789831 100644
--- a/services/backtester/tests/test_engine.py
+++ b/services/backtester/tests/test_engine.py
@@ -1,20 +1,19 @@
"""Tests for the BacktestEngine."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
from unittest.mock import MagicMock
-
-from shared.models import Candle, Signal, OrderSide
-
from backtester.engine import BacktestEngine
+from shared.models import Candle, OrderSide, Signal
+
def make_candle(symbol: str, price: float, timeframe: str = "1h") -> Candle:
return Candle(
symbol=symbol,
timeframe=timeframe,
- open_time=datetime.now(timezone.utc),
+ open_time=datetime.now(UTC),
open=Decimal(str(price)),
high=Decimal(str(price * 1.01)),
low=Decimal(str(price * 0.99)),
diff --git a/services/backtester/tests/test_metrics.py b/services/backtester/tests/test_metrics.py
index 55f5b6c..13e545e 100644
--- a/services/backtester/tests/test_metrics.py
+++ b/services/backtester/tests/test_metrics.py
@@ -1,17 +1,16 @@
"""Tests for detailed backtest metrics."""
import math
-from datetime import datetime, timedelta, timezone
+from datetime import UTC, datetime, timedelta
from decimal import Decimal
import pytest
-
from backtester.metrics import TradeRecord, compute_detailed_metrics
def _make_trade(side: str, price: str, minutes_offset: int = 0) -> TradeRecord:
return TradeRecord(
- time=datetime(2025, 1, 1, tzinfo=timezone.utc) + timedelta(minutes=minutes_offset),
+ time=datetime(2025, 1, 1, tzinfo=UTC) + timedelta(minutes=minutes_offset),
symbol="AAPL",
side=side,
price=Decimal(price),
@@ -124,7 +123,7 @@ def test_consecutive_losses():
def test_risk_free_rate_affects_sharpe():
"""Higher risk-free rate should lower Sharpe ratio."""
- base = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ base = datetime(2025, 1, 1, tzinfo=UTC)
trades = [
TradeRecord(
time=base, symbol="AAPL", side="BUY", price=Decimal("100"), quantity=Decimal("1")
@@ -184,7 +183,7 @@ def test_daily_returns_populated():
def test_fee_subtracted_from_pnl():
"""Fees should be subtracted from trade PnL."""
- base = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ base = datetime(2025, 1, 1, tzinfo=UTC)
trades_with_fees = [
TradeRecord(
time=base,
diff --git a/services/backtester/tests/test_simulator.py b/services/backtester/tests/test_simulator.py
index 62e2cdb..f85594f 100644
--- a/services/backtester/tests/test_simulator.py
+++ b/services/backtester/tests/test_simulator.py
@@ -1,11 +1,12 @@
"""Tests for the OrderSimulator."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
-from shared.models import OrderSide, Signal
from backtester.simulator import OrderSimulator
+from shared.models import OrderSide, Signal
+
def make_signal(
symbol: str,
@@ -135,7 +136,7 @@ def test_stop_loss_triggers():
signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1")
sim.execute(signal, stop_loss=Decimal("48000"))
- ts = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ ts = datetime(2025, 1, 1, tzinfo=UTC)
closed = sim.check_stops(
candle_high=Decimal("50500"),
candle_low=Decimal("47500"), # below stop_loss
@@ -153,7 +154,7 @@ def test_take_profit_triggers():
signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1")
sim.execute(signal, take_profit=Decimal("55000"))
- ts = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ ts = datetime(2025, 1, 1, tzinfo=UTC)
closed = sim.check_stops(
candle_high=Decimal("56000"), # above take_profit
candle_low=Decimal("50000"),
@@ -171,7 +172,7 @@ def test_stop_not_triggered_within_range():
signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1")
sim.execute(signal, stop_loss=Decimal("48000"), take_profit=Decimal("55000"))
- ts = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ ts = datetime(2025, 1, 1, tzinfo=UTC)
closed = sim.check_stops(
candle_high=Decimal("52000"),
candle_low=Decimal("49000"),
@@ -212,7 +213,7 @@ def test_short_stop_loss():
signal = make_signal("AAPL", OrderSide.SELL, "50000", "0.1")
sim.execute(signal, stop_loss=Decimal("52000"))
- ts = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ ts = datetime(2025, 1, 1, tzinfo=UTC)
closed = sim.check_stops(
candle_high=Decimal("53000"), # above stop_loss
candle_low=Decimal("49000"),
diff --git a/services/backtester/tests/test_walk_forward.py b/services/backtester/tests/test_walk_forward.py
index 96abb6e..b1aa12c 100644
--- a/services/backtester/tests/test_walk_forward.py
+++ b/services/backtester/tests/test_walk_forward.py
@@ -1,18 +1,18 @@
"""Tests for walk-forward analysis."""
import sys
-from pathlib import Path
+from datetime import UTC, datetime, timedelta
from decimal import Decimal
-from datetime import datetime, timedelta, timezone
-
+from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "strategy-engine"))
-from shared.models import Candle
from backtester.walk_forward import WalkForwardEngine, WalkForwardResult
from strategies.rsi_strategy import RsiStrategy
+from shared.models import Candle
+
def _generate_candles(n=100, base_price=100.0):
candles = []
@@ -23,7 +23,7 @@ def _generate_candles(n=100, base_price=100.0):
Candle(
symbol="AAPL",
timeframe="1h",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc) + timedelta(hours=i),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC) + timedelta(hours=i),
open=Decimal(str(price)),
high=Decimal(str(price + 5)),
low=Decimal(str(price - 5)),
diff --git a/services/data-collector/Dockerfile b/services/data-collector/Dockerfile
index 8cb8af4..4d154c5 100644
--- a/services/data-collector/Dockerfile
+++ b/services/data-collector/Dockerfile
@@ -1,8 +1,15 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/data-collector/ services/data-collector/
RUN pip install --no-cache-dir ./services/data-collector
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
ENV PYTHONPATH=/app
+USER appuser
CMD ["python", "-m", "data_collector.main"]
diff --git a/services/data-collector/src/data_collector/main.py b/services/data-collector/src/data_collector/main.py
index b42b34c..2d44848 100644
--- a/services/data-collector/src/data_collector/main.py
+++ b/services/data-collector/src/data_collector/main.py
@@ -2,6 +2,9 @@
import asyncio
+import aiohttp
+
+from data_collector.config import CollectorConfig
from shared.alpaca import AlpacaClient
from shared.broker import RedisBroker
from shared.db import Database
@@ -11,8 +14,7 @@ from shared.logging import setup_logging
from shared.metrics import ServiceMetrics
from shared.models import Candle
from shared.notifier import TelegramNotifier
-
-from data_collector.config import CollectorConfig
+from shared.shutdown import GracefulShutdown
# Health check port: base + 0
HEALTH_PORT_OFFSET = 0
@@ -45,8 +47,10 @@ async def fetch_latest_bars(
volume=Decimal(str(bar["v"])),
)
candles.append(candle)
- except Exception as exc:
- log.warning("fetch_bar_failed", symbol=symbol, error=str(exc))
+ except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
+ log.warning("fetch_bar_network_error", symbol=symbol, error=str(exc))
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("fetch_bar_parse_error", symbol=symbol, error=str(exc))
return candles
@@ -56,18 +60,18 @@ async def run() -> None:
metrics = ServiceMetrics("data_collector")
notifier = TelegramNotifier(
- bot_token=config.telegram_bot_token,
+ bot_token=config.telegram_bot_token.get_secret_value(),
chat_id=config.telegram_chat_id,
)
- db = Database(config.database_url)
+ db = Database(config.database_url.get_secret_value())
await db.connect()
- broker = RedisBroker(config.redis_url)
+ broker = RedisBroker(config.redis_url.get_secret_value())
alpaca = AlpacaClient(
- api_key=config.alpaca_api_key,
- api_secret=config.alpaca_api_secret,
+ api_key=config.alpaca_api_key.get_secret_value(),
+ api_secret=config.alpaca_api_secret.get_secret_value(),
paper=config.alpaca_paper,
)
@@ -83,14 +87,17 @@ async def run() -> None:
symbols = config.symbols
timeframe = config.timeframes[0] if config.timeframes else "1Day"
+ shutdown = GracefulShutdown()
+ shutdown.install_handlers()
+
log.info("starting", symbols=symbols, timeframe=timeframe, poll_interval=poll_interval)
try:
- while True:
+ while not shutdown.is_shutting_down:
# Check if market is open
try:
is_open = await alpaca.is_market_open()
- except Exception:
+ except (aiohttp.ClientError, ConnectionError, TimeoutError):
is_open = False
if is_open:
@@ -109,7 +116,7 @@ async def run() -> None:
await asyncio.sleep(poll_interval)
except Exception as exc:
- log.error("fatal_error", error=str(exc))
+ log.error("fatal_error", error=str(exc), exc_info=True)
await notifier.send_error(str(exc), "data-collector")
raise
finally:
diff --git a/services/data-collector/tests/test_storage.py b/services/data-collector/tests/test_storage.py
index ffffa40..51f3aee 100644
--- a/services/data-collector/tests/test_storage.py
+++ b/services/data-collector/tests/test_storage.py
@@ -1,19 +1,20 @@
"""Tests for storage module."""
-import pytest
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
-from shared.models import Candle
+import pytest
from data_collector.storage import CandleStorage
+from shared.models import Candle
+
def _make_candle(symbol: str = "AAPL") -> Candle:
return Candle(
symbol=symbol,
timeframe="1m",
- open_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC),
open=Decimal("30000"),
high=Decimal("30100"),
low=Decimal("29900"),
diff --git a/services/news-collector/Dockerfile b/services/news-collector/Dockerfile
index a8e5902..7accee2 100644
--- a/services/news-collector/Dockerfile
+++ b/services/news-collector/Dockerfile
@@ -1,9 +1,17 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/news-collector/ services/news-collector/
RUN pip install --no-cache-dir ./services/news-collector
-RUN python -c "import nltk; nltk.download('vader_lexicon', quiet=True)"
+RUN python -c "import nltk; nltk.download('vader_lexicon', download_dir='/usr/local/nltk_data')"
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
+COPY --from=builder /usr/local/nltk_data /usr/local/nltk_data
ENV PYTHONPATH=/app
+USER appuser
CMD ["python", "-m", "news_collector.main"]
diff --git a/services/news-collector/pyproject.toml b/services/news-collector/pyproject.toml
index 14c856a..6e62b70 100644
--- a/services/news-collector/pyproject.toml
+++ b/services/news-collector/pyproject.toml
@@ -3,12 +3,7 @@ name = "news-collector"
version = "0.1.0"
description = "News and sentiment data collector service"
requires-python = ">=3.12"
-dependencies = [
- "trading-shared",
- "feedparser>=6.0",
- "nltk>=3.8",
- "aiohttp>=3.9",
-]
+dependencies = ["trading-shared", "feedparser>=6.0,<7", "nltk>=3.8,<4", "aiohttp>=3.9,<4"]
[project.optional-dependencies]
dev = [
diff --git a/services/news-collector/src/news_collector/collectors/fear_greed.py b/services/news-collector/src/news_collector/collectors/fear_greed.py
index f79f716..42e8f88 100644
--- a/services/news-collector/src/news_collector/collectors/fear_greed.py
+++ b/services/news-collector/src/news_collector/collectors/fear_greed.py
@@ -2,7 +2,6 @@
import logging
from dataclasses import dataclass
-from typing import Optional
import aiohttp
@@ -26,7 +25,7 @@ class FearGreedCollector(BaseCollector):
async def is_available(self) -> bool:
return True
- async def _fetch_index(self) -> Optional[dict]:
+ async def _fetch_index(self) -> dict | None:
headers = {"User-Agent": "Mozilla/5.0"}
try:
async with aiohttp.ClientSession() as session:
@@ -50,7 +49,7 @@ class FearGreedCollector(BaseCollector):
return "Greed"
return "Extreme Greed"
- async def collect(self) -> Optional[FearGreedResult]:
+ async def collect(self) -> FearGreedResult | None:
data = await self._fetch_index()
if data is None:
return None
diff --git a/services/news-collector/src/news_collector/collectors/fed.py b/services/news-collector/src/news_collector/collectors/fed.py
index fce4842..52128e5 100644
--- a/services/news-collector/src/news_collector/collectors/fed.py
+++ b/services/news-collector/src/news_collector/collectors/fed.py
@@ -3,7 +3,7 @@
import asyncio
import logging
from calendar import timegm
-from datetime import datetime, timezone
+from datetime import UTC, datetime
import feedparser
from nltk.sentiment.vader import SentimentIntensityAnalyzer
@@ -76,10 +76,10 @@ class FedCollector(BaseCollector):
if published_parsed:
try:
ts = timegm(published_parsed)
- return datetime.fromtimestamp(ts, tz=timezone.utc)
+ return datetime.fromtimestamp(ts, tz=UTC)
except Exception:
pass
- return datetime.now(timezone.utc)
+ return datetime.now(UTC)
async def collect(self) -> list[NewsItem]:
try:
diff --git a/services/news-collector/src/news_collector/collectors/finnhub.py b/services/news-collector/src/news_collector/collectors/finnhub.py
index 13e3602..67cb455 100644
--- a/services/news-collector/src/news_collector/collectors/finnhub.py
+++ b/services/news-collector/src/news_collector/collectors/finnhub.py
@@ -1,7 +1,7 @@
"""Finnhub news collector with VADER sentiment analysis."""
import logging
-from datetime import datetime, timezone
+from datetime import UTC, datetime
import aiohttp
from nltk.sentiment.vader import SentimentIntensityAnalyzer
@@ -64,7 +64,7 @@ class FinnhubCollector(BaseCollector):
sentiment = sentiment_scores["compound"]
ts = article.get("datetime", 0)
- published_at = datetime.fromtimestamp(ts, tz=timezone.utc)
+ published_at = datetime.fromtimestamp(ts, tz=UTC)
related = article.get("related", "")
symbols = [t.strip() for t in related.split(",") if t.strip()] if related else []
diff --git a/services/news-collector/src/news_collector/collectors/reddit.py b/services/news-collector/src/news_collector/collectors/reddit.py
index 226a2f9..4e9d6f5 100644
--- a/services/news-collector/src/news_collector/collectors/reddit.py
+++ b/services/news-collector/src/news_collector/collectors/reddit.py
@@ -2,7 +2,7 @@
import logging
import re
-from datetime import datetime, timezone
+from datetime import UTC, datetime
import aiohttp
from nltk.sentiment.vader import SentimentIntensityAnalyzer
@@ -78,7 +78,7 @@ class RedditCollector(BaseCollector):
symbols = list(dict.fromkeys(_TICKER_PATTERN.findall(combined)))
created_utc = post_data.get("created_utc", 0)
- published_at = datetime.fromtimestamp(created_utc, tz=timezone.utc)
+ published_at = datetime.fromtimestamp(created_utc, tz=UTC)
items.append(
NewsItem(
diff --git a/services/news-collector/src/news_collector/collectors/rss.py b/services/news-collector/src/news_collector/collectors/rss.py
index ddf8503..bca0e9f 100644
--- a/services/news-collector/src/news_collector/collectors/rss.py
+++ b/services/news-collector/src/news_collector/collectors/rss.py
@@ -3,7 +3,7 @@
import asyncio
import logging
import re
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from time import mktime
import feedparser
@@ -56,10 +56,10 @@ class RSSCollector(BaseCollector):
if parsed_time:
try:
ts = mktime(parsed_time)
- return datetime.fromtimestamp(ts, tz=timezone.utc)
+ return datetime.fromtimestamp(ts, tz=UTC)
except Exception:
pass
- return datetime.now(timezone.utc)
+ return datetime.now(UTC)
async def collect(self) -> list[NewsItem]:
try:
diff --git a/services/news-collector/src/news_collector/collectors/sec_edgar.py b/services/news-collector/src/news_collector/collectors/sec_edgar.py
index ca1d070..d88518f 100644
--- a/services/news-collector/src/news_collector/collectors/sec_edgar.py
+++ b/services/news-collector/src/news_collector/collectors/sec_edgar.py
@@ -1,13 +1,13 @@
"""SEC EDGAR filing collector (free, no API key required)."""
import logging
-from datetime import datetime, timezone
+from datetime import UTC, datetime
import aiohttp
from nltk.sentiment.vader import SentimentIntensityAnalyzer
-from shared.models import NewsCategory, NewsItem
from news_collector.collectors.base import BaseCollector
+from shared.models import NewsCategory, NewsItem
logger = logging.getLogger(__name__)
@@ -58,7 +58,7 @@ class SecEdgarCollector(BaseCollector):
async def collect(self) -> list[NewsItem]:
filings_data = await self._fetch_recent_filings()
items = []
- today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
+ today = datetime.now(UTC).strftime("%Y-%m-%d")
for company_data in filings_data:
tickers = [t["ticker"] for t in company_data.get("tickers", [])]
@@ -87,9 +87,7 @@ class SecEdgarCollector(BaseCollector):
headline=headline,
summary=desc,
url=f"https://www.sec.gov/cgi-bin/browse-edgar?action=getcompany&accession={accession}",
- published_at=datetime.strptime(filing_date, "%Y-%m-%d").replace(
- tzinfo=timezone.utc
- ),
+ published_at=datetime.strptime(filing_date, "%Y-%m-%d").replace(tzinfo=UTC),
symbols=tickers,
sentiment=self._vader.polarity_scores(headline)["compound"],
category=NewsCategory.FILING,
diff --git a/services/news-collector/src/news_collector/collectors/truth_social.py b/services/news-collector/src/news_collector/collectors/truth_social.py
index 33ebc86..e2acd88 100644
--- a/services/news-collector/src/news_collector/collectors/truth_social.py
+++ b/services/news-collector/src/news_collector/collectors/truth_social.py
@@ -2,7 +2,7 @@
import logging
import re
-from datetime import datetime, timezone
+from datetime import UTC, datetime
import aiohttp
from nltk.sentiment.vader import SentimentIntensityAnalyzer
@@ -67,7 +67,7 @@ class TruthSocialCollector(BaseCollector):
try:
published_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
except Exception:
- published_at = datetime.now(timezone.utc)
+ published_at = datetime.now(UTC)
items.append(
NewsItem(
diff --git a/services/news-collector/src/news_collector/config.py b/services/news-collector/src/news_collector/config.py
index 70d98f1..6e78eba 100644
--- a/services/news-collector/src/news_collector/config.py
+++ b/services/news-collector/src/news_collector/config.py
@@ -5,6 +5,3 @@ from shared.config import Settings
class NewsCollectorConfig(Settings):
health_port: int = 8084
- finnhub_api_key: str = ""
- news_poll_interval: int = 300
- sentiment_aggregate_interval: int = 900
diff --git a/services/news-collector/src/news_collector/main.py b/services/news-collector/src/news_collector/main.py
index 3493f7c..c39fa67 100644
--- a/services/news-collector/src/news_collector/main.py
+++ b/services/news-collector/src/news_collector/main.py
@@ -1,8 +1,18 @@
"""News Collector Service — fetches news from multiple sources and aggregates sentiment."""
import asyncio
-from datetime import datetime, timezone
+from datetime import UTC, datetime
+import aiohttp
+
+from news_collector.collectors.fear_greed import FearGreedCollector
+from news_collector.collectors.fed import FedCollector
+from news_collector.collectors.finnhub import FinnhubCollector
+from news_collector.collectors.reddit import RedditCollector
+from news_collector.collectors.rss import RSSCollector
+from news_collector.collectors.sec_edgar import SecEdgarCollector
+from news_collector.collectors.truth_social import TruthSocialCollector
+from news_collector.config import NewsCollectorConfig
from shared.broker import RedisBroker
from shared.db import Database
from shared.events import NewsEvent
@@ -11,20 +21,9 @@ from shared.logging import setup_logging
from shared.metrics import ServiceMetrics
from shared.models import NewsItem
from shared.notifier import TelegramNotifier
-from shared.sentiment_models import MarketSentiment
from shared.sentiment import SentimentAggregator
-
-from news_collector.config import NewsCollectorConfig
-from news_collector.collectors.finnhub import FinnhubCollector
-from news_collector.collectors.rss import RSSCollector
-from news_collector.collectors.sec_edgar import SecEdgarCollector
-from news_collector.collectors.truth_social import TruthSocialCollector
-from news_collector.collectors.reddit import RedditCollector
-from news_collector.collectors.fear_greed import FearGreedCollector
-from news_collector.collectors.fed import FedCollector
-
-# Health check port: base + 4
-HEALTH_PORT_OFFSET = 4
+from shared.sentiment_models import MarketSentiment
+from shared.shutdown import GracefulShutdown
async def run_collector_once(collector, db: Database, broker: RedisBroker) -> int:
@@ -53,9 +52,15 @@ async def run_collector_loop(collector, db: Database, broker: RedisBroker, log)
collector=collector.name,
count=count,
)
- except Exception as exc:
+ except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
log.warning(
- "collector_error",
+ "collector_network_error",
+ collector=collector.name,
+ error=str(exc),
+ )
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning(
+ "collector_parse_error",
collector=collector.name,
error=str(exc),
)
@@ -74,7 +79,7 @@ async def run_fear_greed_loop(collector: FearGreedCollector, db: Database, log)
vix=None,
fed_stance="neutral",
market_regime=_determine_regime(result.fear_greed, None),
- updated_at=datetime.now(timezone.utc),
+ updated_at=datetime.now(UTC),
)
await db.upsert_market_sentiment(ms)
log.info(
@@ -82,8 +87,10 @@ async def run_fear_greed_loop(collector: FearGreedCollector, db: Database, log)
value=result.fear_greed,
label=result.fear_greed_label,
)
- except Exception as exc:
- log.warning("fear_greed_error", error=str(exc))
+ except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
+ log.warning("fear_greed_network_error", error=str(exc))
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("fear_greed_parse_error", error=str(exc))
await asyncio.sleep(collector.poll_interval)
@@ -93,14 +100,16 @@ async def run_aggregator_loop(db: Database, interval: int, log) -> None:
while True:
await asyncio.sleep(interval)
try:
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
news_items = await db.get_recent_news(hours=24)
scores = aggregator.aggregate(news_items, now)
for score in scores.values():
await db.upsert_symbol_score(score)
log.info("aggregation_complete", symbols=len(scores))
- except Exception as exc:
- log.warning("aggregator_error", error=str(exc))
+ except (ConnectionError, TimeoutError) as exc:
+ log.warning("aggregator_network_error", error=str(exc))
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("aggregator_parse_error", error=str(exc))
def _determine_regime(fear_greed: int, vix: float | None) -> str:
@@ -115,14 +124,14 @@ async def run() -> None:
metrics = ServiceMetrics("news_collector")
notifier = TelegramNotifier(
- bot_token=config.telegram_bot_token,
+ bot_token=config.telegram_bot_token.get_secret_value(),
chat_id=config.telegram_chat_id,
)
- db = Database(config.database_url)
+ db = Database(config.database_url.get_secret_value())
await db.connect()
- broker = RedisBroker(config.redis_url)
+ broker = RedisBroker(config.redis_url.get_secret_value())
health = HealthCheckServer(
"news-collector",
@@ -133,7 +142,7 @@ async def run() -> None:
metrics.service_up.labels(service="news-collector").set(1)
# Build collectors
- finnhub = FinnhubCollector(api_key=config.finnhub_api_key)
+ finnhub = FinnhubCollector(api_key=config.finnhub_api_key.get_secret_value())
rss = RSSCollector()
sec = SecEdgarCollector()
truth = TruthSocialCollector()
@@ -143,6 +152,9 @@ async def run() -> None:
news_collectors = [finnhub, rss, sec, truth, reddit, fed]
+ shutdown = GracefulShutdown()
+ shutdown.install_handlers()
+
log.info(
"starting",
collectors=[c.name for c in news_collectors],
@@ -151,14 +163,13 @@ async def run() -> None:
)
try:
- tasks = []
- for collector in news_collectors:
- tasks.append(
- asyncio.create_task(
- run_collector_loop(collector, db, broker, log),
- name=f"collector-{collector.name}",
- )
+ tasks = [
+ asyncio.create_task(
+ run_collector_loop(collector, db, broker, log),
+ name=f"collector-{collector.name}",
)
+ for collector in news_collectors
+ ]
tasks.append(
asyncio.create_task(
run_fear_greed_loop(fear_greed, db, log),
@@ -171,9 +182,9 @@ async def run() -> None:
name="aggregator-loop",
)
)
- await asyncio.gather(*tasks)
+ await shutdown.wait()
except Exception as exc:
- log.error("fatal_error", error=str(exc))
+ log.error("fatal_error", error=str(exc), exc_info=True)
await notifier.send_error(str(exc), "news-collector")
raise
finally:
diff --git a/services/news-collector/tests/test_fear_greed.py b/services/news-collector/tests/test_fear_greed.py
index d483aa6..e8bd8f0 100644
--- a/services/news-collector/tests/test_fear_greed.py
+++ b/services/news-collector/tests/test_fear_greed.py
@@ -1,8 +1,8 @@
"""Tests for CNN Fear & Greed Index collector."""
-import pytest
from unittest.mock import AsyncMock, patch
+import pytest
from news_collector.collectors.fear_greed import FearGreedCollector
diff --git a/services/news-collector/tests/test_fed.py b/services/news-collector/tests/test_fed.py
index d1a736b..7f1c46c 100644
--- a/services/news-collector/tests/test_fed.py
+++ b/services/news-collector/tests/test_fed.py
@@ -1,7 +1,8 @@
"""Tests for Federal Reserve collector."""
-import pytest
from unittest.mock import AsyncMock, patch
+
+import pytest
from news_collector.collectors.fed import FedCollector
diff --git a/services/news-collector/tests/test_finnhub.py b/services/news-collector/tests/test_finnhub.py
index a4cf169..3af65b8 100644
--- a/services/news-collector/tests/test_finnhub.py
+++ b/services/news-collector/tests/test_finnhub.py
@@ -1,8 +1,8 @@
"""Tests for Finnhub news collector."""
-import pytest
from unittest.mock import AsyncMock, patch
+import pytest
from news_collector.collectors.finnhub import FinnhubCollector
diff --git a/services/news-collector/tests/test_main.py b/services/news-collector/tests/test_main.py
index 66190dc..f85569a 100644
--- a/services/news-collector/tests/test_main.py
+++ b/services/news-collector/tests/test_main.py
@@ -1,16 +1,18 @@
"""Tests for news collector scheduler."""
+from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
-from datetime import datetime, timezone
-from shared.models import NewsCategory, NewsItem
+
from news_collector.main import run_collector_once
+from shared.models import NewsCategory, NewsItem
+
async def test_run_collector_once_stores_and_publishes():
mock_item = NewsItem(
source="test",
headline="Test news",
- published_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ published_at=datetime(2026, 4, 2, tzinfo=UTC),
sentiment=0.5,
category=NewsCategory.MACRO,
)
diff --git a/services/news-collector/tests/test_reddit.py b/services/news-collector/tests/test_reddit.py
index 440b173..31b1dc1 100644
--- a/services/news-collector/tests/test_reddit.py
+++ b/services/news-collector/tests/test_reddit.py
@@ -1,7 +1,8 @@
"""Tests for Reddit collector."""
-import pytest
from unittest.mock import AsyncMock, patch
+
+import pytest
from news_collector.collectors.reddit import RedditCollector
diff --git a/services/news-collector/tests/test_rss.py b/services/news-collector/tests/test_rss.py
index e03250a..7242c75 100644
--- a/services/news-collector/tests/test_rss.py
+++ b/services/news-collector/tests/test_rss.py
@@ -1,8 +1,8 @@
"""Tests for RSS news collector."""
-import pytest
from unittest.mock import AsyncMock, patch
+import pytest
from news_collector.collectors.rss import RSSCollector
diff --git a/services/news-collector/tests/test_sec_edgar.py b/services/news-collector/tests/test_sec_edgar.py
index 5d4f69f..b0faf18 100644
--- a/services/news-collector/tests/test_sec_edgar.py
+++ b/services/news-collector/tests/test_sec_edgar.py
@@ -1,9 +1,9 @@
"""Tests for SEC EDGAR filing collector."""
-import pytest
-from datetime import datetime, timezone
-from unittest.mock import AsyncMock, patch, MagicMock
+from datetime import UTC, datetime
+from unittest.mock import AsyncMock, MagicMock, patch
+import pytest
from news_collector.collectors.sec_edgar import SecEdgarCollector
@@ -37,7 +37,7 @@ async def test_collect_parses_filings(collector):
}
mock_datetime = MagicMock(spec=datetime)
- mock_datetime.now.return_value = datetime(2026, 4, 2, tzinfo=timezone.utc)
+ mock_datetime.now.return_value = datetime(2026, 4, 2, tzinfo=UTC)
mock_datetime.strptime = datetime.strptime
with patch.object(
diff --git a/services/news-collector/tests/test_truth_social.py b/services/news-collector/tests/test_truth_social.py
index 91ddb9d..52f1e46 100644
--- a/services/news-collector/tests/test_truth_social.py
+++ b/services/news-collector/tests/test_truth_social.py
@@ -1,7 +1,8 @@
"""Tests for Truth Social collector."""
-import pytest
from unittest.mock import AsyncMock, patch
+
+import pytest
from news_collector.collectors.truth_social import TruthSocialCollector
diff --git a/services/order-executor/Dockerfile b/services/order-executor/Dockerfile
index bc8b21c..376afec 100644
--- a/services/order-executor/Dockerfile
+++ b/services/order-executor/Dockerfile
@@ -1,8 +1,15 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/order-executor/ services/order-executor/
RUN pip install --no-cache-dir ./services/order-executor
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
ENV PYTHONPATH=/app
+USER appuser
CMD ["python", "-m", "order_executor.main"]
diff --git a/services/order-executor/src/order_executor/executor.py b/services/order-executor/src/order_executor/executor.py
index a71e762..fd502cd 100644
--- a/services/order-executor/src/order_executor/executor.py
+++ b/services/order-executor/src/order_executor/executor.py
@@ -1,18 +1,18 @@
"""Order execution logic."""
-import structlog
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
-from typing import Any, Optional
+from typing import Any
+
+import structlog
+from order_executor.risk_manager import RiskManager
from shared.broker import RedisBroker
from shared.db import Database
from shared.events import OrderEvent
from shared.models import Order, OrderStatus, OrderType, Signal
from shared.notifier import TelegramNotifier
-from order_executor.risk_manager import RiskManager
-
logger = structlog.get_logger()
@@ -35,7 +35,7 @@ class OrderExecutor:
self.notifier = notifier
self.dry_run = dry_run
- async def execute(self, signal: Signal) -> Optional[Order]:
+ async def execute(self, signal: Signal) -> Order | None:
"""Run risk checks and place an order for the given signal."""
# Fetch buying power from Alpaca
balance = await self.exchange.get_buying_power()
@@ -71,7 +71,7 @@ class OrderExecutor:
if self.dry_run:
order.status = OrderStatus.FILLED
- order.filled_at = datetime.now(timezone.utc)
+ order.filled_at = datetime.now(UTC)
logger.info(
"order_filled_dry_run",
side=str(order.side),
@@ -87,7 +87,7 @@ class OrderExecutor:
type="market",
)
order.status = OrderStatus.FILLED
- order.filled_at = datetime.now(timezone.utc)
+ order.filled_at = datetime.now(UTC)
logger.info(
"order_filled",
side=str(order.side),
diff --git a/services/order-executor/src/order_executor/main.py b/services/order-executor/src/order_executor/main.py
index 51ab286..99f88e1 100644
--- a/services/order-executor/src/order_executor/main.py
+++ b/services/order-executor/src/order_executor/main.py
@@ -3,6 +3,11 @@
import asyncio
from decimal import Decimal
+import aiohttp
+
+from order_executor.config import ExecutorConfig
+from order_executor.executor import OrderExecutor
+from order_executor.risk_manager import RiskManager
from shared.alpaca import AlpacaClient
from shared.broker import RedisBroker
from shared.db import Database
@@ -11,10 +16,7 @@ from shared.healthcheck import HealthCheckServer
from shared.logging import setup_logging
from shared.metrics import ServiceMetrics
from shared.notifier import TelegramNotifier
-
-from order_executor.config import ExecutorConfig
-from order_executor.executor import OrderExecutor
-from order_executor.risk_manager import RiskManager
+from shared.shutdown import GracefulShutdown
# Health check port: base + 2
HEALTH_PORT_OFFSET = 2
@@ -26,18 +28,18 @@ async def run() -> None:
metrics = ServiceMetrics("order_executor")
notifier = TelegramNotifier(
- bot_token=config.telegram_bot_token,
+ bot_token=config.telegram_bot_token.get_secret_value(),
chat_id=config.telegram_chat_id,
)
- db = Database(config.database_url)
+ db = Database(config.database_url.get_secret_value())
await db.connect()
- broker = RedisBroker(config.redis_url)
+ broker = RedisBroker(config.redis_url.get_secret_value())
alpaca = AlpacaClient(
- api_key=config.alpaca_api_key,
- api_secret=config.alpaca_api_secret,
+ api_key=config.alpaca_api_key.get_secret_value(),
+ api_secret=config.alpaca_api_secret.get_secret_value(),
paper=config.alpaca_paper,
)
@@ -83,6 +85,9 @@ async def run() -> None:
await broker.ensure_group(stream, GROUP)
+ shutdown = GracefulShutdown()
+ shutdown.install_handlers()
+
log.info("started", stream=stream, dry_run=config.dry_run)
try:
@@ -94,10 +99,15 @@ async def run() -> None:
if event.type == EventType.SIGNAL:
await executor.execute(event.data)
await broker.ack(stream, GROUP, msg_id)
+ except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
+ log.warning("pending_network_error", error=str(exc), msg_id=msg_id)
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("pending_parse_error", error=str(exc), msg_id=msg_id)
+ await broker.ack(stream, GROUP, msg_id)
except Exception as exc:
- log.error("pending_failed", error=str(exc), msg_id=msg_id)
+ log.error("pending_failed", error=str(exc), msg_id=msg_id, exc_info=True)
- while True:
+ while not shutdown.is_shutting_down:
messages = await broker.read_group(stream, GROUP, CONSUMER, count=10, block=5000)
for msg_id, msg in messages:
try:
@@ -110,8 +120,19 @@ async def run() -> None:
service="order-executor", event_type="signal"
).inc()
await broker.ack(stream, GROUP, msg_id)
+ except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
+ log.warning("process_network_error", error=str(exc))
+ metrics.errors_total.labels(
+ service="order-executor", error_type="network"
+ ).inc()
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("process_parse_error", error=str(exc))
+ await broker.ack(stream, GROUP, msg_id)
+ metrics.errors_total.labels(
+ service="order-executor", error_type="validation"
+ ).inc()
except Exception as exc:
- log.error("process_failed", error=str(exc))
+ log.error("process_failed", error=str(exc), exc_info=True)
metrics.errors_total.labels(
service="order-executor", error_type="processing"
).inc()
diff --git a/services/order-executor/src/order_executor/risk_manager.py b/services/order-executor/src/order_executor/risk_manager.py
index 5a05746..811a862 100644
--- a/services/order-executor/src/order_executor/risk_manager.py
+++ b/services/order-executor/src/order_executor/risk_manager.py
@@ -1,12 +1,12 @@
"""Risk management for order execution."""
+import math
+from collections import deque
from dataclasses import dataclass
-from datetime import datetime, timezone, timedelta
+from datetime import UTC, datetime, timedelta
from decimal import Decimal
-from collections import deque
-import math
-from shared.models import Signal, OrderSide, Position
+from shared.models import OrderSide, Position, Signal
@dataclass
@@ -123,15 +123,13 @@ class RiskManager:
else:
self._consecutive_losses += 1
if self._consecutive_losses >= self._max_consecutive_losses:
- self._paused_until = datetime.now(timezone.utc) + timedelta(
- minutes=self._loss_pause_minutes
- )
+ self._paused_until = datetime.now(UTC) + timedelta(minutes=self._loss_pause_minutes)
def is_paused(self) -> bool:
"""Check if trading is paused due to consecutive losses."""
if self._paused_until is None:
return False
- if datetime.now(timezone.utc) >= self._paused_until:
+ if datetime.now(UTC) >= self._paused_until:
self._paused_until = None
self._consecutive_losses = 0
return False
@@ -233,9 +231,9 @@ class RiskManager:
mean_a = sum(returns_a) / len(returns_a)
mean_b = sum(returns_b) / len(returns_b)
- cov = sum((a - mean_a) * (b - mean_b) for a, b in zip(returns_a, returns_b)) / len(
- returns_a
- )
+ cov = sum(
+ (a - mean_a) * (b - mean_b) for a, b in zip(returns_a, returns_b, strict=True)
+ ) / len(returns_a)
std_a = math.sqrt(sum((a - mean_a) ** 2 for a in returns_a) / len(returns_a))
std_b = math.sqrt(sum((b - mean_b) ** 2 for b in returns_b) / len(returns_b))
@@ -280,7 +278,11 @@ class RiskManager:
min_len = min(len(r) for r in all_returns)
portfolio_returns = []
for i in range(min_len):
- pr = sum(w * r[-(min_len - i)] for w, r in zip(weights, all_returns) if len(r) > i)
+ pr = sum(
+ w * r[-(min_len - i)]
+ for w, r in zip(weights, all_returns, strict=False)
+ if len(r) > i
+ )
portfolio_returns.append(pr)
if not portfolio_returns:
diff --git a/services/order-executor/tests/test_executor.py b/services/order-executor/tests/test_executor.py
index dd823d7..cda6b72 100644
--- a/services/order-executor/tests/test_executor.py
+++ b/services/order-executor/tests/test_executor.py
@@ -4,11 +4,11 @@ from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock
import pytest
-
-from shared.models import OrderSide, OrderStatus, Signal
from order_executor.executor import OrderExecutor
from order_executor.risk_manager import RiskCheckResult, RiskManager
+from shared.models import OrderSide, OrderStatus, Signal
+
def make_signal(side: OrderSide = OrderSide.BUY, price: str = "100", quantity: str = "1") -> Signal:
return Signal(
diff --git a/services/order-executor/tests/test_risk_manager.py b/services/order-executor/tests/test_risk_manager.py
index 3d5175b..66e769c 100644
--- a/services/order-executor/tests/test_risk_manager.py
+++ b/services/order-executor/tests/test_risk_manager.py
@@ -2,9 +2,9 @@
from decimal import Decimal
+from order_executor.risk_manager import RiskManager
from shared.models import OrderSide, Position, Signal
-from order_executor.risk_manager import RiskManager
def make_signal(side: OrderSide, price: str, quantity: str, symbol: str = "AAPL") -> Signal:
diff --git a/services/portfolio-manager/Dockerfile b/services/portfolio-manager/Dockerfile
index b1a7681..0fa3f35 100644
--- a/services/portfolio-manager/Dockerfile
+++ b/services/portfolio-manager/Dockerfile
@@ -1,8 +1,15 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/portfolio-manager/ services/portfolio-manager/
RUN pip install --no-cache-dir ./services/portfolio-manager
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
ENV PYTHONPATH=/app
+USER appuser
CMD ["python", "-m", "portfolio_manager.main"]
diff --git a/services/portfolio-manager/src/portfolio_manager/main.py b/services/portfolio-manager/src/portfolio_manager/main.py
index a6823ae..f885aa8 100644
--- a/services/portfolio-manager/src/portfolio_manager/main.py
+++ b/services/portfolio-manager/src/portfolio_manager/main.py
@@ -2,6 +2,10 @@
import asyncio
+import sqlalchemy.exc
+
+from portfolio_manager.config import PortfolioConfig
+from portfolio_manager.portfolio import PortfolioTracker
from shared.broker import RedisBroker
from shared.db import Database
from shared.events import Event, OrderEvent
@@ -9,9 +13,7 @@ from shared.healthcheck import HealthCheckServer
from shared.logging import setup_logging
from shared.metrics import ServiceMetrics
from shared.notifier import TelegramNotifier
-
-from portfolio_manager.config import PortfolioConfig
-from portfolio_manager.portfolio import PortfolioTracker
+from shared.shutdown import GracefulShutdown
ORDERS_STREAM = "orders"
@@ -51,8 +53,12 @@ async def snapshot_loop(
while True:
try:
await save_snapshot(db, tracker, notifier, log)
+ except (sqlalchemy.exc.OperationalError, ConnectionError, TimeoutError) as exc:
+ log.warning("snapshot_db_error", error=str(exc))
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("snapshot_data_error", error=str(exc))
except Exception as exc:
- log.error("snapshot_failed", error=str(exc))
+ log.error("snapshot_failed", error=str(exc), exc_info=True)
await asyncio.sleep(interval_hours * 3600)
@@ -61,10 +67,10 @@ async def run() -> None:
log = setup_logging("portfolio-manager", config.log_level, config.log_format)
metrics = ServiceMetrics("portfolio_manager")
notifier = TelegramNotifier(
- bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id
+ bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id
)
- broker = RedisBroker(config.redis_url)
+ broker = RedisBroker(config.redis_url.get_secret_value())
tracker = PortfolioTracker()
health = HealthCheckServer(
@@ -76,13 +82,16 @@ async def run() -> None:
await health.start()
metrics.service_up.labels(service="portfolio-manager").set(1)
- db = Database(config.database_url)
+ db = Database(config.database_url.get_secret_value())
await db.connect()
snapshot_task = asyncio.create_task(
snapshot_loop(db, tracker, notifier, config.snapshot_interval_hours, log)
)
+ shutdown = GracefulShutdown()
+ shutdown.install_handlers()
+
GROUP = "portfolio-manager"
CONSUMER = "portfolio-1"
log.info("service_started", stream=ORDERS_STREAM)
@@ -108,12 +117,16 @@ async def run() -> None:
service="portfolio-manager", event_type="order"
).inc()
await broker.ack(ORDERS_STREAM, GROUP, msg_id)
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("pending_parse_error", error=str(exc), msg_id=msg_id)
+ await broker.ack(ORDERS_STREAM, GROUP, msg_id)
+ metrics.errors_total.labels(service="portfolio-manager", error_type="validation").inc()
except Exception as exc:
- log.error("pending_process_failed", error=str(exc), msg_id=msg_id)
+ log.error("pending_process_failed", error=str(exc), msg_id=msg_id, exc_info=True)
metrics.errors_total.labels(service="portfolio-manager", error_type="processing").inc()
try:
- while True:
+ while not shutdown.is_shutting_down:
messages = await broker.read_group(ORDERS_STREAM, GROUP, CONSUMER, count=10, block=1000)
for msg_id, msg in messages:
try:
@@ -134,13 +147,21 @@ async def run() -> None:
service="portfolio-manager", event_type="order"
).inc()
await broker.ack(ORDERS_STREAM, GROUP, msg_id)
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("message_parse_error", error=str(exc), msg_id=msg_id)
+ await broker.ack(ORDERS_STREAM, GROUP, msg_id)
+ metrics.errors_total.labels(
+ service="portfolio-manager", error_type="validation"
+ ).inc()
except Exception as exc:
- log.exception("message_processing_failed", error=str(exc), msg_id=msg_id)
+ log.error(
+ "message_processing_failed", error=str(exc), msg_id=msg_id, exc_info=True
+ )
metrics.errors_total.labels(
service="portfolio-manager", error_type="processing"
).inc()
except Exception as exc:
- log.error("fatal_error", error=str(exc))
+ log.error("fatal_error", error=str(exc), exc_info=True)
await notifier.send_error(str(exc), "portfolio-manager")
raise
finally:
diff --git a/services/portfolio-manager/tests/test_portfolio.py b/services/portfolio-manager/tests/test_portfolio.py
index 365dc1a..c8a6894 100644
--- a/services/portfolio-manager/tests/test_portfolio.py
+++ b/services/portfolio-manager/tests/test_portfolio.py
@@ -2,9 +2,10 @@
from decimal import Decimal
-from shared.models import Order, OrderSide, OrderStatus, OrderType
from portfolio_manager.portfolio import PortfolioTracker
+from shared.models import Order, OrderSide, OrderStatus, OrderType
+
def make_order(side: OrderSide, price: str, quantity: str) -> Order:
"""Helper to create a filled Order."""
diff --git a/services/portfolio-manager/tests/test_snapshot.py b/services/portfolio-manager/tests/test_snapshot.py
index ec5e92d..f2026e2 100644
--- a/services/portfolio-manager/tests/test_snapshot.py
+++ b/services/portfolio-manager/tests/test_snapshot.py
@@ -1,9 +1,10 @@
"""Tests for save_snapshot in portfolio-manager."""
-import pytest
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock
+import pytest
+
from shared.models import Position
diff --git a/services/strategy-engine/Dockerfile b/services/strategy-engine/Dockerfile
index de635dc..f1484e9 100644
--- a/services/strategy-engine/Dockerfile
+++ b/services/strategy-engine/Dockerfile
@@ -1,9 +1,16 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/strategy-engine/ services/strategy-engine/
RUN pip install --no-cache-dir ./services/strategy-engine
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
COPY services/strategy-engine/strategies/ /app/strategies/
ENV PYTHONPATH=/app
+USER appuser
CMD ["python", "-m", "strategy_engine.main"]
diff --git a/services/strategy-engine/pyproject.toml b/services/strategy-engine/pyproject.toml
index 4f5b6be..e4bfb12 100644
--- a/services/strategy-engine/pyproject.toml
+++ b/services/strategy-engine/pyproject.toml
@@ -3,11 +3,7 @@ name = "strategy-engine"
version = "0.1.0"
description = "Plugin-based strategy execution engine"
requires-python = ">=3.12"
-dependencies = [
- "pandas>=2.0",
- "numpy>=1.20",
- "trading-shared",
-]
+dependencies = ["pandas>=2.1,<3", "numpy>=1.26,<3", "trading-shared"]
[project.optional-dependencies]
dev = ["pytest>=8.0", "pytest-asyncio>=0.23"]
diff --git a/services/strategy-engine/src/strategy_engine/config.py b/services/strategy-engine/src/strategy_engine/config.py
index 2a9cb43..9fd9c49 100644
--- a/services/strategy-engine/src/strategy_engine/config.py
+++ b/services/strategy-engine/src/strategy_engine/config.py
@@ -7,9 +7,3 @@ class StrategyConfig(Settings):
symbols: list[str] = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"]
timeframes: list[str] = ["1m"]
strategy_params: dict = {}
- selector_candidates_time: str = "15:00"
- selector_filter_time: str = "15:15"
- selector_final_time: str = "15:30"
- selector_max_picks: int = 3
- anthropic_api_key: str = ""
- anthropic_model: str = "claude-sonnet-4-20250514"
diff --git a/services/strategy-engine/src/strategy_engine/engine.py b/services/strategy-engine/src/strategy_engine/engine.py
index d401aee..4b2c468 100644
--- a/services/strategy-engine/src/strategy_engine/engine.py
+++ b/services/strategy-engine/src/strategy_engine/engine.py
@@ -2,11 +2,11 @@
import logging
-from shared.broker import RedisBroker
-from shared.events import CandleEvent, SignalEvent, Event
-
from strategies.base import BaseStrategy
+from shared.broker import RedisBroker
+from shared.events import CandleEvent, Event, SignalEvent
+
logger = logging.getLogger(__name__)
@@ -26,7 +26,7 @@ class StrategyEngine:
try:
event = Event.from_dict(raw)
except Exception as exc:
- logger.warning("Failed to parse event: %s – %s", raw, exc)
+ logger.warning("Failed to parse event: %s - %s", raw, exc)
continue
if not isinstance(event, CandleEvent):
diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py
index 5a30766..3d73058 100644
--- a/services/strategy-engine/src/strategy_engine/main.py
+++ b/services/strategy-engine/src/strategy_engine/main.py
@@ -1,9 +1,11 @@
"""Strategy Engine Service entry point."""
import asyncio
+import zoneinfo
from datetime import datetime
from pathlib import Path
-import zoneinfo
+
+import aiohttp
from shared.alpaca import AlpacaClient
from shared.broker import RedisBroker
@@ -13,7 +15,7 @@ from shared.logging import setup_logging
from shared.metrics import ServiceMetrics
from shared.notifier import TelegramNotifier
from shared.sentiment_models import MarketSentiment
-
+from shared.shutdown import GracefulShutdown
from strategy_engine.config import StrategyConfig
from strategy_engine.engine import StrategyEngine
from strategy_engine.plugin_loader import load_strategies
@@ -63,8 +65,12 @@ async def run_stock_selector(
log.info("stock_selector_complete", picks=[s.symbol for s in selections])
else:
log.info("stock_selector_no_picks")
+ except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
+ log.warning("stock_selector_network_error", error=str(exc))
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("stock_selector_data_error", error=str(exc))
except Exception as exc:
- log.error("stock_selector_error", error=str(exc))
+ log.error("stock_selector_error", error=str(exc), exc_info=True)
await asyncio.sleep(120) # Sleep past this minute
else:
await asyncio.sleep(30)
@@ -76,18 +82,18 @@ async def run() -> None:
metrics = ServiceMetrics("strategy_engine")
notifier = TelegramNotifier(
- bot_token=config.telegram_bot_token,
+ bot_token=config.telegram_bot_token.get_secret_value(),
chat_id=config.telegram_chat_id,
)
- broker = RedisBroker(config.redis_url)
+ broker = RedisBroker(config.redis_url.get_secret_value())
- db = Database(config.database_url)
+ db = Database(config.database_url.get_secret_value())
await db.connect()
alpaca = AlpacaClient(
- api_key=config.alpaca_api_key,
- api_secret=config.alpaca_api_secret,
+ api_key=config.alpaca_api_key.get_secret_value(),
+ api_secret=config.alpaca_api_secret.get_secret_value(),
paper=config.alpaca_paper,
)
@@ -97,6 +103,9 @@ async def run() -> None:
params = config.strategy_params.get(strategy.name, {})
strategy.configure(params)
+ shutdown = GracefulShutdown()
+ shutdown.install_handlers()
+
log.info("loaded_strategies", count=len(strategies), names=[s.name for s in strategies])
engine = StrategyEngine(broker=broker, strategies=strategies)
@@ -117,12 +126,12 @@ async def run() -> None:
task = asyncio.create_task(process_symbol(engine, stream, log))
tasks.append(task)
- if config.anthropic_api_key:
+ if config.anthropic_api_key.get_secret_value():
selector = StockSelector(
db=db,
broker=broker,
alpaca=alpaca,
- anthropic_api_key=config.anthropic_api_key,
+ anthropic_api_key=config.anthropic_api_key.get_secret_value(),
anthropic_model=config.anthropic_model,
max_picks=config.selector_max_picks,
)
@@ -131,9 +140,9 @@ async def run() -> None:
)
log.info("stock_selector_enabled", time=config.selector_final_time)
- await asyncio.gather(*tasks)
+ await shutdown.wait()
except Exception as exc:
- log.error("fatal_error", error=str(exc))
+ log.error("fatal_error", error=str(exc), exc_info=True)
await notifier.send_error(str(exc), "strategy-engine")
raise
finally:
diff --git a/services/strategy-engine/src/strategy_engine/plugin_loader.py b/services/strategy-engine/src/strategy_engine/plugin_loader.py
index 62e4160..57680db 100644
--- a/services/strategy-engine/src/strategy_engine/plugin_loader.py
+++ b/services/strategy-engine/src/strategy_engine/plugin_loader.py
@@ -5,7 +5,6 @@ import sys
from pathlib import Path
import yaml
-
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/src/strategy_engine/stock_selector.py b/services/strategy-engine/src/strategy_engine/stock_selector.py
index 268d557..8657b93 100644
--- a/services/strategy-engine/src/strategy_engine/stock_selector.py
+++ b/services/strategy-engine/src/strategy_engine/stock_selector.py
@@ -1,9 +1,10 @@
"""3-stage stock selector engine: sentiment → technical → LLM."""
+import asyncio
import json
import logging
import re
-from datetime import datetime, timezone
+from datetime import UTC, datetime
import aiohttp
@@ -18,18 +19,12 @@ logger = logging.getLogger(__name__)
ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages"
-def _parse_llm_selections(text: str) -> list[SelectedStock]:
- """Parse LLM response into SelectedStock list.
-
- Handles both bare JSON arrays and markdown code blocks.
- Returns empty list on any parse error.
- """
- # Try to extract JSON from markdown code block first
+def _extract_json_array(text: str) -> list[dict] | None:
+ """Extract a JSON array from text that may contain markdown code blocks."""
code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL)
if code_block:
raw = code_block.group(1)
else:
- # Try to find a bare JSON array
array_match = re.search(r"\[.*\]", text, re.DOTALL)
if array_match:
raw = array_match.group(0)
@@ -38,27 +33,38 @@ def _parse_llm_selections(text: str) -> list[SelectedStock]:
try:
data = json.loads(raw)
- if not isinstance(data, list):
- return []
- selections = []
- for item in data:
- if not isinstance(item, dict):
- continue
- try:
- selection = SelectedStock(
- symbol=item["symbol"],
- side=OrderSide(item["side"]),
- conviction=float(item["conviction"]),
- reason=item.get("reason", ""),
- key_news=item.get("key_news", []),
- )
- selections.append(selection)
- except (KeyError, ValueError) as e:
- logger.warning("Skipping invalid selection item: %s", e)
- return selections
+ if isinstance(data, list):
+ return [item for item in data if isinstance(item, dict)]
+ return None
except (json.JSONDecodeError, TypeError):
+ return None
+
+
+def _parse_llm_selections(text: str) -> list[SelectedStock]:
+ """Parse LLM response into SelectedStock list.
+
+ Handles both bare JSON arrays and markdown code blocks.
+ Returns empty list on any parse error.
+ """
+ items = _extract_json_array(text)
+ if items is None:
return []
+ selections = []
+ for item in items:
+ try:
+ selection = SelectedStock(
+ symbol=item["symbol"],
+ side=OrderSide(item["side"]),
+ conviction=float(item["conviction"]),
+ reason=item.get("reason", ""),
+ key_news=item.get("key_news", []),
+ )
+ selections.append(selection)
+ except (KeyError, ValueError) as e:
+ logger.warning("Skipping invalid selection item: %s", e)
+ return selections
+
class SentimentCandidateSource:
"""Generates candidates from DB sentiment scores."""
@@ -92,7 +98,7 @@ class LLMCandidateSource:
self._api_key = api_key
self._model = model
- async def get_candidates(self) -> list[Candidate]:
+ async def get_candidates(self, session: aiohttp.ClientSession | None = None) -> list[Candidate]:
news_items = await self._db.get_recent_news(hours=24)
if not news_items:
return []
@@ -110,26 +116,29 @@ class LLMCandidateSource:
"Headlines:\n" + "\n".join(headlines)
)
+ own_session = session is None
+ if own_session:
+ session = aiohttp.ClientSession()
+
try:
- async with aiohttp.ClientSession() as session:
- async with session.post(
- ANTHROPIC_API_URL,
- headers={
- "x-api-key": self._api_key,
- "anthropic-version": "2023-06-01",
- "content-type": "application/json",
- },
- json={
- "model": self._model,
- "max_tokens": 1024,
- "messages": [{"role": "user", "content": prompt}],
- },
- ) as resp:
- if resp.status != 200:
- body = await resp.text()
- logger.error("LLM candidate source error %d: %s", resp.status, body)
- return []
- data = await resp.json()
+ async with session.post(
+ ANTHROPIC_API_URL,
+ headers={
+ "x-api-key": self._api_key,
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json",
+ },
+ json={
+ "model": self._model,
+ "max_tokens": 1024,
+ "messages": [{"role": "user", "content": prompt}],
+ },
+ ) as resp:
+ if resp.status != 200:
+ body = await resp.text()
+ logger.error("LLM candidate source error %d: %s", resp.status, body)
+ return []
+ data = await resp.json()
content = data.get("content", [])
text = ""
@@ -141,40 +150,32 @@ class LLMCandidateSource:
except Exception as e:
logger.error("LLMCandidateSource error: %s", e)
return []
+ finally:
+ if own_session:
+ await session.close()
def _parse_candidates(self, text: str) -> list[Candidate]:
- code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL)
- if code_block:
- raw = code_block.group(1)
- else:
- array_match = re.search(r"\[.*\]", text, re.DOTALL)
- raw = array_match.group(0) if array_match else text.strip()
+ items = _extract_json_array(text)
+ if items is None:
+ return []
- try:
- items = json.loads(raw)
- if not isinstance(items, list):
- return []
- candidates = []
- for item in items:
- if not isinstance(item, dict):
- continue
- try:
- direction_str = item.get("direction", "BUY")
- direction = OrderSide(direction_str)
- except ValueError:
- direction = None
- candidates.append(
- Candidate(
- symbol=item["symbol"],
- source="llm",
- direction=direction,
- score=float(item.get("score", 0.5)),
- reason=item.get("reason", ""),
- )
+ candidates = []
+ for item in items:
+ try:
+ direction_str = item.get("direction", "BUY")
+ direction = OrderSide(direction_str)
+ except ValueError:
+ direction = None
+ candidates.append(
+ Candidate(
+ symbol=item["symbol"],
+ source="llm",
+ direction=direction,
+ score=float(item.get("score", 0.5)),
+ reason=item.get("reason", ""),
)
- return candidates
- except (json.JSONDecodeError, TypeError, KeyError):
- return []
+ )
+ return candidates
def _compute_rsi(closes: list[float], period: int = 14) -> float:
@@ -217,6 +218,18 @@ class StockSelector:
self._api_key = anthropic_api_key
self._model = anthropic_model
self._max_picks = max_picks
+ self._http_session: aiohttp.ClientSession | None = None
+ self._session_lock = asyncio.Lock()
+
+ async def _ensure_session(self) -> aiohttp.ClientSession:
+ async with self._session_lock:
+ if self._http_session is None or self._http_session.closed:
+ self._http_session = aiohttp.ClientSession()
+ return self._http_session
+
+ async def close(self) -> None:
+ if self._http_session and not self._http_session.closed:
+ await self._http_session.close()
async def select(self) -> list[SelectedStock]:
"""Run the full 3-stage pipeline and return selected stocks."""
@@ -235,8 +248,9 @@ class StockSelector:
sentiment_source = SentimentCandidateSource(self._db)
llm_source = LLMCandidateSource(self._db, self._api_key, self._model)
+ session = await self._ensure_session()
sentiment_candidates = await sentiment_source.get_candidates()
- llm_candidates = await llm_source.get_candidates()
+ llm_candidates = await llm_source.get_candidates(session=session)
candidates = self._merge_candidates(sentiment_candidates, llm_candidates)
if not candidates:
@@ -253,7 +267,7 @@ class StockSelector:
selections = await self._llm_final_select(filtered, market_sentiment)
# Persist and publish
- today = datetime.now(timezone.utc).date()
+ today = datetime.now(UTC).date()
sentiment_snapshot = {
"fear_greed": market_sentiment.fear_greed,
"market_regime": market_sentiment.market_regime,
@@ -372,25 +386,25 @@ class StockSelector:
)
try:
- async with aiohttp.ClientSession() as session:
- async with session.post(
- ANTHROPIC_API_URL,
- headers={
- "x-api-key": self._api_key,
- "anthropic-version": "2023-06-01",
- "content-type": "application/json",
- },
- json={
- "model": self._model,
- "max_tokens": 1024,
- "messages": [{"role": "user", "content": prompt}],
- },
- ) as resp:
- if resp.status != 200:
- body = await resp.text()
- logger.error("LLM final select error %d: %s", resp.status, body)
- return []
- data = await resp.json()
+ session = await self._ensure_session()
+ async with session.post(
+ ANTHROPIC_API_URL,
+ headers={
+ "x-api-key": self._api_key,
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json",
+ },
+ json={
+ "model": self._model,
+ "max_tokens": 1024,
+ "messages": [{"role": "user", "content": prompt}],
+ },
+ ) as resp:
+ if resp.status != 200:
+ body = await resp.text()
+ logger.error("LLM final select error %d: %s", resp.status, body)
+ return []
+ data = await resp.json()
content = data.get("content", [])
text = ""
diff --git a/services/strategy-engine/strategies/base.py b/services/strategy-engine/strategies/base.py
index d5be675..1d9d289 100644
--- a/services/strategy-engine/strategies/base.py
+++ b/services/strategy-engine/strategies/base.py
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from collections import deque
from decimal import Decimal
-from typing import Optional
import pandas as pd
@@ -102,7 +101,7 @@ class BaseStrategy(ABC):
def _calculate_atr_stops(
self, entry_price: Decimal, side: str
- ) -> tuple[Optional[Decimal], Optional[Decimal]]:
+ ) -> tuple[Decimal | None, Decimal | None]:
"""Calculate ATR-based stop-loss and take-profit.
Returns (stop_loss, take_profit) as Decimal or (None, None) if not enough data.
@@ -131,7 +130,7 @@ class BaseStrategy(ABC):
return sl, tp
- def _apply_filters(self, signal: Signal) -> Optional[Signal]:
+ def _apply_filters(self, signal: Signal) -> Signal | None:
"""Apply all filters to a signal. Returns signal with SL/TP or None if filtered out."""
if signal is None:
return None
diff --git a/services/strategy-engine/strategies/bollinger_strategy.py b/services/strategy-engine/strategies/bollinger_strategy.py
index ebe7967..02ff09a 100644
--- a/services/strategy-engine/strategies/bollinger_strategy.py
+++ b/services/strategy-engine/strategies/bollinger_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/combined_strategy.py b/services/strategy-engine/strategies/combined_strategy.py
index ba92485..f562918 100644
--- a/services/strategy-engine/strategies/combined_strategy.py
+++ b/services/strategy-engine/strategies/combined_strategy.py
@@ -2,7 +2,7 @@
from decimal import Decimal
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/ema_crossover_strategy.py b/services/strategy-engine/strategies/ema_crossover_strategy.py
index 68d0ba3..9c181f3 100644
--- a/services/strategy-engine/strategies/ema_crossover_strategy.py
+++ b/services/strategy-engine/strategies/ema_crossover_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/grid_strategy.py b/services/strategy-engine/strategies/grid_strategy.py
index 283bfe5..491252e 100644
--- a/services/strategy-engine/strategies/grid_strategy.py
+++ b/services/strategy-engine/strategies/grid_strategy.py
@@ -1,9 +1,8 @@
from decimal import Decimal
-from typing import Optional
import numpy as np
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
@@ -17,7 +16,7 @@ class GridStrategy(BaseStrategy):
self._grid_count: int = 5
self._quantity: Decimal = Decimal("0.01")
self._grid_levels: list[float] = []
- self._last_zone: Optional[int] = None
+ self._last_zone: int | None = None
self._exit_threshold_pct: float = 5.0
self._out_of_range: bool = False
self._in_position: bool = False # Track if we have any grid positions
diff --git a/services/strategy-engine/strategies/indicators/__init__.py b/services/strategy-engine/strategies/indicators/__init__.py
index 3c713e6..01637b7 100644
--- a/services/strategy-engine/strategies/indicators/__init__.py
+++ b/services/strategy-engine/strategies/indicators/__init__.py
@@ -1,21 +1,21 @@
"""Reusable technical indicator functions."""
-from strategies.indicators.trend import ema, sma, macd, adx
-from strategies.indicators.volatility import atr, bollinger_bands, keltner_channels
from strategies.indicators.momentum import rsi, stochastic
-from strategies.indicators.volume import volume_sma, volume_ratio, obv
+from strategies.indicators.trend import adx, ema, macd, sma
+from strategies.indicators.volatility import atr, bollinger_bands, keltner_channels
+from strategies.indicators.volume import obv, volume_ratio, volume_sma
__all__ = [
- "ema",
- "sma",
- "macd",
"adx",
"atr",
"bollinger_bands",
+ "ema",
"keltner_channels",
+ "macd",
+ "obv",
"rsi",
+ "sma",
"stochastic",
- "volume_sma",
"volume_ratio",
- "obv",
+ "volume_sma",
]
diff --git a/services/strategy-engine/strategies/indicators/momentum.py b/services/strategy-engine/strategies/indicators/momentum.py
index c479452..a82210b 100644
--- a/services/strategy-engine/strategies/indicators/momentum.py
+++ b/services/strategy-engine/strategies/indicators/momentum.py
@@ -1,7 +1,7 @@
"""Momentum indicators: RSI, Stochastic."""
-import pandas as pd
import numpy as np
+import pandas as pd
def rsi(closes: pd.Series, period: int = 14) -> pd.Series:
diff --git a/services/strategy-engine/strategies/indicators/trend.py b/services/strategy-engine/strategies/indicators/trend.py
index c94a071..1085199 100644
--- a/services/strategy-engine/strategies/indicators/trend.py
+++ b/services/strategy-engine/strategies/indicators/trend.py
@@ -1,7 +1,7 @@
"""Trend indicators: EMA, SMA, MACD, ADX."""
-import pandas as pd
import numpy as np
+import pandas as pd
def sma(series: pd.Series, period: int) -> pd.Series:
diff --git a/services/strategy-engine/strategies/indicators/volatility.py b/services/strategy-engine/strategies/indicators/volatility.py
index c16143e..da82f26 100644
--- a/services/strategy-engine/strategies/indicators/volatility.py
+++ b/services/strategy-engine/strategies/indicators/volatility.py
@@ -1,7 +1,7 @@
"""Volatility indicators: ATR, Bollinger Bands, Keltner Channels."""
-import pandas as pd
import numpy as np
+import pandas as pd
def atr(
diff --git a/services/strategy-engine/strategies/indicators/volume.py b/services/strategy-engine/strategies/indicators/volume.py
index 502f1ce..d7c6471 100644
--- a/services/strategy-engine/strategies/indicators/volume.py
+++ b/services/strategy-engine/strategies/indicators/volume.py
@@ -1,7 +1,7 @@
"""Volume indicators: Volume SMA, Volume Ratio, OBV."""
-import pandas as pd
import numpy as np
+import pandas as pd
def volume_sma(volumes: pd.Series, period: int = 20) -> pd.Series:
diff --git a/services/strategy-engine/strategies/macd_strategy.py b/services/strategy-engine/strategies/macd_strategy.py
index 356a42b..b5aea07 100644
--- a/services/strategy-engine/strategies/macd_strategy.py
+++ b/services/strategy-engine/strategies/macd_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/moc_strategy.py b/services/strategy-engine/strategies/moc_strategy.py
index 7eaa59e..cbc8440 100644
--- a/services/strategy-engine/strategies/moc_strategy.py
+++ b/services/strategy-engine/strategies/moc_strategy.py
@@ -8,12 +8,12 @@ Rules:
"""
from collections import deque
-from decimal import Decimal
from datetime import datetime
+from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/rsi_strategy.py b/services/strategy-engine/strategies/rsi_strategy.py
index 0646d8c..2df080d 100644
--- a/services/strategy-engine/strategies/rsi_strategy.py
+++ b/services/strategy-engine/strategies/rsi_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/volume_profile_strategy.py b/services/strategy-engine/strategies/volume_profile_strategy.py
index ef2ae14..67b5c23 100644
--- a/services/strategy-engine/strategies/volume_profile_strategy.py
+++ b/services/strategy-engine/strategies/volume_profile_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import numpy as np
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
@@ -137,7 +137,7 @@ class VolumeProfileStrategy(BaseStrategy):
if result is None:
return None
- poc, va_low, va_high, hvn_levels, lvn_levels = result
+ poc, va_low, va_high, hvn_levels, _lvn_levels = result
if close < va_low:
self._was_below_va = True
diff --git a/services/strategy-engine/strategies/vwap_strategy.py b/services/strategy-engine/strategies/vwap_strategy.py
index d64950e..4ee4952 100644
--- a/services/strategy-engine/strategies/vwap_strategy.py
+++ b/services/strategy-engine/strategies/vwap_strategy.py
@@ -1,7 +1,7 @@
from collections import deque
from decimal import Decimal
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
@@ -107,7 +107,7 @@ class VwapStrategy(BaseStrategy):
# Standard deviation of (TP - VWAP) for bands
std_dev = 0.0
if len(self._tp_values) >= 2:
- diffs = [tp - v for tp, v in zip(self._tp_values, self._vwap_values)]
+ diffs = [tp - v for tp, v in zip(self._tp_values, self._vwap_values, strict=True)]
mean_diff = sum(diffs) / len(diffs)
variance = sum((d - mean_diff) ** 2 for d in diffs) / len(diffs)
std_dev = variance**0.5
diff --git a/services/strategy-engine/tests/test_base_filters.py b/services/strategy-engine/tests/test_base_filters.py
index ae9ca05..66adec7 100644
--- a/services/strategy-engine/tests/test_base_filters.py
+++ b/services/strategy-engine/tests/test_base_filters.py
@@ -5,12 +5,13 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
-from shared.models import Candle, Signal, OrderSide
from strategies.base import BaseStrategy
+from shared.models import Candle, OrderSide, Signal
+
class DummyStrategy(BaseStrategy):
name = "dummy"
@@ -45,7 +46,7 @@ def _candle(price=100.0, volume=10.0, high=None, low=None):
return Candle(
symbol="AAPL",
timeframe="1h",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal(str(price)),
high=Decimal(str(h)),
low=Decimal(str(lo)),
diff --git a/services/strategy-engine/tests/test_bollinger_strategy.py b/services/strategy-engine/tests/test_bollinger_strategy.py
index 8261377..70ec66e 100644
--- a/services/strategy-engine/tests/test_bollinger_strategy.py
+++ b/services/strategy-engine/tests/test_bollinger_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the Bollinger Bands strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.bollinger_strategy import BollingerStrategy
from shared.models import Candle, OrderSide
-from strategies.bollinger_strategy import BollingerStrategy
def make_candle(close: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_combined_strategy.py b/services/strategy-engine/tests/test_combined_strategy.py
index 8a4dc74..6a15250 100644
--- a/services/strategy-engine/tests/test_combined_strategy.py
+++ b/services/strategy-engine/tests/test_combined_strategy.py
@@ -5,13 +5,14 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
-import pytest
-from shared.models import Candle, Signal, OrderSide
-from strategies.combined_strategy import CombinedStrategy
+import pytest
from strategies.base import BaseStrategy
+from strategies.combined_strategy import CombinedStrategy
+
+from shared.models import Candle, OrderSide, Signal
class AlwaysBuyStrategy(BaseStrategy):
@@ -74,7 +75,7 @@ def _candle(price=100.0):
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal(str(price)),
high=Decimal(str(price + 10)),
low=Decimal(str(price - 10)),
diff --git a/services/strategy-engine/tests/test_ema_crossover_strategy.py b/services/strategy-engine/tests/test_ema_crossover_strategy.py
index 7028eb0..af2b587 100644
--- a/services/strategy-engine/tests/test_ema_crossover_strategy.py
+++ b/services/strategy-engine/tests/test_ema_crossover_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the EMA Crossover strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.ema_crossover_strategy import EmaCrossoverStrategy
from shared.models import Candle, OrderSide
-from strategies.ema_crossover_strategy import EmaCrossoverStrategy
def make_candle(close: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_engine.py b/services/strategy-engine/tests/test_engine.py
index 2623027..fa888b5 100644
--- a/services/strategy-engine/tests/test_engine.py
+++ b/services/strategy-engine/tests/test_engine.py
@@ -1,21 +1,21 @@
"""Tests for the StrategyEngine."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock
import pytest
+from strategy_engine.engine import StrategyEngine
-from shared.models import Candle, Signal, OrderSide
from shared.events import CandleEvent
-from strategy_engine.engine import StrategyEngine
+from shared.models import Candle, OrderSide, Signal
def make_candle_event() -> dict:
candle = Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("50100"),
low=Decimal("49900"),
diff --git a/services/strategy-engine/tests/test_grid_strategy.py b/services/strategy-engine/tests/test_grid_strategy.py
index 878b900..f697012 100644
--- a/services/strategy-engine/tests/test_grid_strategy.py
+++ b/services/strategy-engine/tests/test_grid_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the Grid strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.grid_strategy import GridStrategy
from shared.models import Candle, OrderSide
-from strategies.grid_strategy import GridStrategy
def make_candle(close: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_indicators.py b/services/strategy-engine/tests/test_indicators.py
index 481569b..3147fc4 100644
--- a/services/strategy-engine/tests/test_indicators.py
+++ b/services/strategy-engine/tests/test_indicators.py
@@ -5,14 +5,13 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
-import pandas as pd
import numpy as np
+import pandas as pd
import pytest
-
-from strategies.indicators.trend import sma, ema, macd, adx
-from strategies.indicators.volatility import atr, bollinger_bands
from strategies.indicators.momentum import rsi, stochastic
-from strategies.indicators.volume import volume_sma, volume_ratio, obv
+from strategies.indicators.trend import adx, ema, macd, sma
+from strategies.indicators.volatility import atr, bollinger_bands
+from strategies.indicators.volume import obv, volume_ratio, volume_sma
class TestTrend:
diff --git a/services/strategy-engine/tests/test_macd_strategy.py b/services/strategy-engine/tests/test_macd_strategy.py
index 556fd4c..7fac16f 100644
--- a/services/strategy-engine/tests/test_macd_strategy.py
+++ b/services/strategy-engine/tests/test_macd_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the MACD strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.macd_strategy import MacdStrategy
from shared.models import Candle, OrderSide
-from strategies.macd_strategy import MacdStrategy
def _candle(price: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(price)),
high=Decimal(str(price)),
low=Decimal(str(price)),
diff --git a/services/strategy-engine/tests/test_moc_strategy.py b/services/strategy-engine/tests/test_moc_strategy.py
index 1928a28..076e846 100644
--- a/services/strategy-engine/tests/test_moc_strategy.py
+++ b/services/strategy-engine/tests/test_moc_strategy.py
@@ -5,19 +5,20 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
-from shared.models import Candle, OrderSide
from strategies.moc_strategy import MocStrategy
+from shared.models import Candle, OrderSide
+
def _candle(price, hour=20, minute=0, volume=100.0, day=1, open_price=None):
op = open_price if open_price is not None else price - 1 # Default: bullish
return Candle(
symbol="AAPL",
timeframe="5Min",
- open_time=datetime(2025, 1, day, hour, minute, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, day, hour, minute, tzinfo=UTC),
open=Decimal(str(op)),
high=Decimal(str(price + 1)),
low=Decimal(str(min(op, price) - 1)),
diff --git a/services/strategy-engine/tests/test_multi_symbol.py b/services/strategy-engine/tests/test_multi_symbol.py
index 671a9d3..922bfc2 100644
--- a/services/strategy-engine/tests/test_multi_symbol.py
+++ b/services/strategy-engine/tests/test_multi_symbol.py
@@ -9,11 +9,13 @@ import pytest
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+from datetime import UTC, datetime
+from decimal import Decimal
+
from strategy_engine.engine import StrategyEngine
+
from shared.events import CandleEvent
from shared.models import Candle
-from decimal import Decimal
-from datetime import datetime, timezone
@pytest.mark.asyncio
@@ -24,7 +26,7 @@ async def test_engine_processes_multiple_streams():
candle_btc = Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("51000"),
low=Decimal("49000"),
@@ -34,7 +36,7 @@ async def test_engine_processes_multiple_streams():
candle_eth = Candle(
symbol="MSFT",
timeframe="1m",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal("3000"),
high=Decimal("3100"),
low=Decimal("2900"),
diff --git a/services/strategy-engine/tests/test_plugin_loader.py b/services/strategy-engine/tests/test_plugin_loader.py
index 5191fc3..7bd450f 100644
--- a/services/strategy-engine/tests/test_plugin_loader.py
+++ b/services/strategy-engine/tests/test_plugin_loader.py
@@ -2,10 +2,8 @@
from pathlib import Path
-
from strategy_engine.plugin_loader import load_strategies
-
STRATEGIES_DIR = Path(__file__).parent.parent / "strategies"
diff --git a/services/strategy-engine/tests/test_rsi_strategy.py b/services/strategy-engine/tests/test_rsi_strategy.py
index 6d31fd5..6c74f0b 100644
--- a/services/strategy-engine/tests/test_rsi_strategy.py
+++ b/services/strategy-engine/tests/test_rsi_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the RSI strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.rsi_strategy import RsiStrategy
from shared.models import Candle, OrderSide
-from strategies.rsi_strategy import RsiStrategy
def make_candle(close: float, idx: int = 0) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_stock_selector.py b/services/strategy-engine/tests/test_stock_selector.py
index ff9d09c..76b8541 100644
--- a/services/strategy-engine/tests/test_stock_selector.py
+++ b/services/strategy-engine/tests/test_stock_selector.py
@@ -1,12 +1,12 @@
"""Tests for stock selector engine."""
+from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
-from datetime import datetime, timezone
-
from strategy_engine.stock_selector import (
SentimentCandidateSource,
StockSelector,
+ _extract_json_array,
_parse_llm_selections,
)
@@ -60,6 +60,37 @@ def test_parse_llm_selections_with_markdown():
assert selections[0].symbol == "TSLA"
+def test_extract_json_array_from_markdown():
+ text = '```json\n[{"symbol": "AAPL", "score": 0.9}]\n```'
+ result = _extract_json_array(text)
+ assert result == [{"symbol": "AAPL", "score": 0.9}]
+
+
+def test_extract_json_array_bare():
+ text = '[{"symbol": "TSLA"}]'
+ result = _extract_json_array(text)
+ assert result == [{"symbol": "TSLA"}]
+
+
+def test_extract_json_array_invalid():
+ assert _extract_json_array("not json") is None
+
+
+def test_extract_json_array_filters_non_dicts():
+ text = '[{"symbol": "AAPL"}, "bad", 42]'
+ result = _extract_json_array(text)
+ assert result == [{"symbol": "AAPL"}]
+
+
+async def test_selector_close():
+ selector = StockSelector(
+ db=MagicMock(), broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test"
+ )
+ # No session yet - close should be safe
+ await selector.close()
+ assert selector._http_session is None
+
+
async def test_selector_blocks_on_risk_off():
mock_db = MagicMock()
mock_db.get_latest_market_sentiment = AsyncMock(
@@ -69,7 +100,7 @@ async def test_selector_blocks_on_risk_off():
"vix": 35.0,
"fed_stance": "neutral",
"market_regime": "risk_off",
- "updated_at": datetime.now(timezone.utc),
+ "updated_at": datetime.now(UTC),
}
)
diff --git a/services/strategy-engine/tests/test_strategy_validation.py b/services/strategy-engine/tests/test_strategy_validation.py
index debab1f..0d9607a 100644
--- a/services/strategy-engine/tests/test_strategy_validation.py
+++ b/services/strategy-engine/tests/test_strategy_validation.py
@@ -1,13 +1,11 @@
import pytest
-
-from strategies.rsi_strategy import RsiStrategy
-from strategies.macd_strategy import MacdStrategy
from strategies.bollinger_strategy import BollingerStrategy
from strategies.ema_crossover_strategy import EmaCrossoverStrategy
from strategies.grid_strategy import GridStrategy
-from strategies.vwap_strategy import VwapStrategy
+from strategies.macd_strategy import MacdStrategy
+from strategies.rsi_strategy import RsiStrategy
from strategies.volume_profile_strategy import VolumeProfileStrategy
-
+from strategies.vwap_strategy import VwapStrategy
# ── RSI ──────────────────────────────────────────────────────────────────
diff --git a/services/strategy-engine/tests/test_volume_profile_strategy.py b/services/strategy-engine/tests/test_volume_profile_strategy.py
index 65ee2e8..f47898c 100644
--- a/services/strategy-engine/tests/test_volume_profile_strategy.py
+++ b/services/strategy-engine/tests/test_volume_profile_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the Volume Profile strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.volume_profile_strategy import VolumeProfileStrategy
from shared.models import Candle, OrderSide
-from strategies.volume_profile_strategy import VolumeProfileStrategy
def make_candle(close: float, volume: float = 1.0) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
@@ -134,13 +134,10 @@ def test_volume_profile_hvn_detection():
# Create a profile with very high volume at price ~100 and low volume elsewhere
# Prices range from 90 to 110, heavy volume concentrated at 100
- candles_data = []
# Low volume at extremes
- for p in [90, 91, 92, 109, 110]:
- candles_data.append((p, 1.0))
+ candles_data = [(p, 1.0) for p in [90, 91, 92, 109, 110]]
# Very high volume around 100
- for _ in range(15):
- candles_data.append((100, 100.0))
+ candles_data.extend((100, 100.0) for _ in range(15))
for price, vol in candles_data:
strategy.on_candle(make_candle(price, vol))
@@ -148,7 +145,7 @@ def test_volume_profile_hvn_detection():
# Access the internal method to verify HVN detection
result = strategy._compute_value_area()
assert result is not None
- poc, va_low, va_high, hvn_levels, lvn_levels = result
+ _poc, _va_low, _va_high, hvn_levels, _lvn_levels = result
# The bin containing price ~100 should have very high volume -> HVN
assert len(hvn_levels) > 0
diff --git a/services/strategy-engine/tests/test_vwap_strategy.py b/services/strategy-engine/tests/test_vwap_strategy.py
index 2c34b01..078d0cf 100644
--- a/services/strategy-engine/tests/test_vwap_strategy.py
+++ b/services/strategy-engine/tests/test_vwap_strategy.py
@@ -1,11 +1,11 @@
"""Tests for the VWAP strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.vwap_strategy import VwapStrategy
from shared.models import Candle, OrderSide
-from strategies.vwap_strategy import VwapStrategy
def make_candle(
@@ -20,7 +20,7 @@ def make_candle(
if low is None:
low = close
if open_time is None:
- open_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
+ open_time = datetime(2024, 1, 1, tzinfo=UTC)
return Candle(
symbol="AAPL",
timeframe="1m",
@@ -111,11 +111,11 @@ def test_vwap_daily_reset():
"""Candles from two different dates cause VWAP to reset."""
strategy = _configured_strategy()
- day1 = datetime(2024, 1, 1, tzinfo=timezone.utc)
- day2 = datetime(2024, 1, 2, tzinfo=timezone.utc)
+ day1 = datetime(2024, 1, 1, tzinfo=UTC)
+ day2 = datetime(2024, 1, 2, tzinfo=UTC)
# Feed 35 candles on day 1 to build VWAP state
- for i in range(35):
+ for _i in range(35):
strategy.on_candle(make_candle(100.0, high=101.0, low=99.0, open_time=day1))
# Verify state is built up
diff --git a/shared/alembic/versions/001_initial_schema.py b/shared/alembic/versions/001_initial_schema.py
index 2bdaafc..7b744ee 100644
--- a/shared/alembic/versions/001_initial_schema.py
+++ b/shared/alembic/versions/001_initial_schema.py
@@ -5,16 +5,16 @@ Revises:
Create Date: 2026-04-01
"""
-from typing import Sequence, Union
+from collections.abc import Sequence
-from alembic import op
import sqlalchemy as sa
+from alembic import op
# revision identifiers, used by Alembic.
revision: str = "001"
-down_revision: Union[str, None] = None
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = None
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
diff --git a/shared/alembic/versions/002_news_sentiment_tables.py b/shared/alembic/versions/002_news_sentiment_tables.py
index 402ff87..d85a634 100644
--- a/shared/alembic/versions/002_news_sentiment_tables.py
+++ b/shared/alembic/versions/002_news_sentiment_tables.py
@@ -5,15 +5,15 @@ Revises: 001
Create Date: 2026-04-02
"""
-from typing import Sequence, Union
+from collections.abc import Sequence
-from alembic import op
import sqlalchemy as sa
+from alembic import op
revision: str = "002"
-down_revision: Union[str, None] = "001"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = "001"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
diff --git a/shared/alembic/versions/003_add_missing_indexes.py b/shared/alembic/versions/003_add_missing_indexes.py
new file mode 100644
index 0000000..7a252d4
--- /dev/null
+++ b/shared/alembic/versions/003_add_missing_indexes.py
@@ -0,0 +1,35 @@
+"""Add missing indexes for common query patterns.
+
+Revision ID: 003
+Revises: 002
+Create Date: 2026-04-02
+"""
+
+from collections.abc import Sequence
+
+from alembic import op
+
+revision: str = "003"
+down_revision: str | None = "002"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+ op.create_index("idx_signals_symbol_created", "signals", ["symbol", "created_at"])
+ op.create_index(
+ "idx_orders_symbol_status_created", "orders", ["symbol", "status", "created_at"]
+ )
+ op.create_index("idx_trades_order_id", "trades", ["order_id"])
+ op.create_index("idx_trades_symbol_traded", "trades", ["symbol", "traded_at"])
+ op.create_index("idx_portfolio_snapshots_at", "portfolio_snapshots", ["snapshot_at"])
+ op.create_index("idx_symbol_scores_symbol", "symbol_scores", ["symbol"], unique=True)
+
+
+def downgrade() -> None:
+ op.drop_index("idx_symbol_scores_symbol", table_name="symbol_scores")
+ op.drop_index("idx_portfolio_snapshots_at", table_name="portfolio_snapshots")
+ op.drop_index("idx_trades_symbol_traded", table_name="trades")
+ op.drop_index("idx_trades_order_id", table_name="trades")
+ op.drop_index("idx_orders_symbol_status_created", table_name="orders")
+ op.drop_index("idx_signals_symbol_created", table_name="signals")
diff --git a/shared/alembic/versions/004_add_signal_detail_columns.py b/shared/alembic/versions/004_add_signal_detail_columns.py
new file mode 100644
index 0000000..4009b6e
--- /dev/null
+++ b/shared/alembic/versions/004_add_signal_detail_columns.py
@@ -0,0 +1,25 @@
+"""Add conviction, stop_loss, take_profit columns to signals table.
+
+Revision ID: 004
+Revises: 003
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+revision = "004"
+down_revision = "003"
+
+
+def upgrade():
+ op.add_column(
+ "signals", sa.Column("conviction", sa.Float, nullable=False, server_default="1.0")
+ )
+ op.add_column("signals", sa.Column("stop_loss", sa.Numeric, nullable=True))
+ op.add_column("signals", sa.Column("take_profit", sa.Numeric, nullable=True))
+
+
+def downgrade():
+ op.drop_column("signals", "take_profit")
+ op.drop_column("signals", "stop_loss")
+ op.drop_column("signals", "conviction")
diff --git a/shared/pyproject.toml b/shared/pyproject.toml
index 830088d..eb74a11 100644
--- a/shared/pyproject.toml
+++ b/shared/pyproject.toml
@@ -4,28 +4,22 @@ version = "0.1.0"
description = "Shared models, events, and utilities for trading platform"
requires-python = ">=3.12"
dependencies = [
- "pydantic>=2.0",
- "pydantic-settings>=2.0",
- "redis>=5.0",
- "asyncpg>=0.29",
- "sqlalchemy[asyncio]>=2.0",
- "alembic>=1.13",
- "structlog>=24.0",
- "prometheus-client>=0.20",
- "pyyaml>=6.0",
- "aiohttp>=3.9",
- "rich>=13.0",
+ "pydantic>=2.8,<3",
+ "pydantic-settings>=2.0,<3",
+ "redis>=5.0,<6",
+ "asyncpg>=0.29,<1",
+ "sqlalchemy[asyncio]>=2.0,<3",
+ "alembic>=1.13,<2",
+ "structlog>=24.0,<25",
+ "prometheus-client>=0.20,<1",
+ "pyyaml>=6.0,<7",
+ "aiohttp>=3.9,<4",
+ "rich>=13.0,<14",
]
[project.optional-dependencies]
-dev = [
- "pytest>=8.0",
- "pytest-asyncio>=0.23",
- "ruff>=0.4",
-]
-claude = [
- "anthropic>=0.40",
-]
+dev = ["pytest>=8.0,<9", "pytest-asyncio>=0.23,<1", "ruff>=0.4,<1"]
+claude = ["anthropic>=0.40,<1"]
[build-system]
requires = ["hatchling"]
diff --git a/shared/src/shared/broker.py b/shared/src/shared/broker.py
index fbe4576..2b96714 100644
--- a/shared/src/shared/broker.py
+++ b/shared/src/shared/broker.py
@@ -5,13 +5,21 @@ from typing import Any
import redis.asyncio
+from shared.resilience import retry_async
+
class RedisBroker:
"""Async Redis Streams broker for publishing and reading events."""
def __init__(self, redis_url: str) -> None:
- self._redis = redis.asyncio.from_url(redis_url)
+ self._redis = redis.asyncio.from_url(
+ redis_url,
+ socket_keepalive=True,
+ health_check_interval=30,
+ retry_on_timeout=True,
+ )
+ @retry_async(max_retries=3, base_delay=0.5, exclude=(ValueError,))
async def publish(self, stream: str, data: dict[str, Any]) -> None:
"""Publish a message to a Redis stream."""
payload = json.dumps(data)
@@ -25,6 +33,7 @@ class RedisBroker:
if "BUSYGROUP" not in str(e):
raise
+ @retry_async(max_retries=3, base_delay=0.5, exclude=(ValueError,))
async def read_group(
self,
stream: str,
@@ -99,6 +108,7 @@ class RedisBroker:
messages.append(json.loads(payload))
return messages
+ @retry_async(max_retries=2, base_delay=0.5)
async def ping(self) -> bool:
"""Ping the Redis server; return True if reachable."""
return await self._redis.ping()
diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py
index 7a947b3..0f1c66e 100644
--- a/shared/src/shared/config.py
+++ b/shared/src/shared/config.py
@@ -1,14 +1,18 @@
"""Shared configuration settings for the trading platform."""
+from pydantic import SecretStr, field_validator
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
- alpaca_api_key: str = ""
- alpaca_api_secret: str = ""
+ alpaca_api_key: SecretStr = SecretStr("")
+ alpaca_api_secret: SecretStr = SecretStr("")
alpaca_paper: bool = True # Use paper trading by default
- redis_url: str = "redis://localhost:6379"
- database_url: str = "postgresql://trading:trading@localhost:5432/trading"
+ redis_url: SecretStr = SecretStr("redis://localhost:6379")
+ database_url: SecretStr = SecretStr("postgresql://trading:trading@localhost:5432/trading")
+ db_pool_size: int = 20
+ db_max_overflow: int = 10
+ db_pool_recycle: int = 3600
log_level: str = "INFO"
risk_max_position_size: float = 0.1
risk_stop_loss_pct: float = 5.0
@@ -27,24 +31,45 @@ class Settings(BaseSettings):
risk_max_consecutive_losses: int = 5
risk_loss_pause_minutes: int = 60
dry_run: bool = True
- telegram_bot_token: str = ""
+ telegram_bot_token: SecretStr = SecretStr("")
telegram_chat_id: str = ""
telegram_enabled: bool = False
log_format: str = "json"
health_port: int = 8080
- circuit_breaker_threshold: int = 5
- circuit_breaker_timeout: int = 60
metrics_auth_token: str = "" # If set, /health and /metrics require Bearer token
+ # API security
+ api_auth_token: SecretStr = SecretStr("")
+ cors_origins: str = "http://localhost:3000"
# News collector
- finnhub_api_key: str = ""
+ finnhub_api_key: SecretStr = SecretStr("")
news_poll_interval: int = 300
sentiment_aggregate_interval: int = 900
# Stock selector
- selector_candidates_time: str = "15:00"
- selector_filter_time: str = "15:15"
selector_final_time: str = "15:30"
selector_max_picks: int = 3
# LLM
- anthropic_api_key: str = ""
+ anthropic_api_key: SecretStr = SecretStr("")
anthropic_model: str = "claude-sonnet-4-20250514"
model_config = {"env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"}
+
+ @field_validator("risk_max_position_size")
+ @classmethod
+ def validate_position_size(cls, v: float) -> float:
+        if v <= 0 or v > 1:
+            raise ValueError("risk_max_position_size must be in (0, 1]: greater than 0 and at most 1")
+ return v
+
+ @field_validator("health_port")
+ @classmethod
+ def validate_health_port(cls, v: int) -> int:
+ if v < 1024 or v > 65535:
+ raise ValueError("health_port must be between 1024 and 65535")
+ return v
+
+ @field_validator("log_level")
+ @classmethod
+ def validate_log_level(cls, v: str) -> str:
+ valid = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
+ if v.upper() not in valid:
+ raise ValueError(f"log_level must be one of {valid}")
+ return v.upper()
diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py
index 9cc8686..8fee000 100644
--- a/shared/src/shared/db.py
+++ b/shared/src/shared/db.py
@@ -3,26 +3,25 @@
import json
import uuid
from contextlib import asynccontextmanager
-from datetime import datetime, date, timedelta, timezone
+from datetime import UTC, date, datetime, timedelta
from decimal import Decimal
-from typing import Optional
from sqlalchemy import select, update
-from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
-from shared.models import Candle, Signal, Order, OrderStatus, NewsItem
-from shared.sentiment_models import SymbolScore, MarketSentiment
+from shared.models import Candle, NewsItem, Order, OrderStatus, Signal
from shared.sa_models import (
Base,
CandleRow,
- SignalRow,
+ MarketSentimentRow,
+ NewsItemRow,
OrderRow,
PortfolioSnapshotRow,
- NewsItemRow,
- SymbolScoreRow,
- MarketSentimentRow,
+ SignalRow,
StockSelectionRow,
+ SymbolScoreRow,
)
+from shared.sentiment_models import MarketSentiment, SymbolScore
class Database:
@@ -36,9 +35,24 @@ class Database:
self._engine = None
self._session_factory = None
- async def connect(self) -> None:
+ async def connect(
+ self,
+ pool_size: int = 20,
+ max_overflow: int = 10,
+ pool_recycle: int = 3600,
+ ) -> None:
"""Create the async engine, session factory, and all tables."""
- self._engine = create_async_engine(self._database_url)
+        if self._database_url.startswith("sqlite"):
+            # SQLite async engines use NullPool and reject pool-sizing kwargs
+ self._engine = create_async_engine(self._database_url)
+ else:
+ self._engine = create_async_engine(
+ self._database_url,
+ pool_pre_ping=True,
+ pool_size=pool_size,
+ max_overflow=max_overflow,
+ pool_recycle=pool_recycle,
+ )
self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
@@ -98,6 +112,9 @@ class Database:
price=signal.price,
quantity=signal.quantity,
reason=signal.reason,
+ conviction=signal.conviction,
+ stop_loss=signal.stop_loss,
+ take_profit=signal.take_profit,
created_at=signal.created_at,
)
async with self._session_factory() as session:
@@ -134,7 +151,7 @@ class Database:
self,
order_id: str,
status: OrderStatus,
- filled_at: Optional[datetime] = None,
+ filled_at: datetime | None = None,
) -> None:
"""Update the status (and optionally filled_at) of an order."""
stmt = (
@@ -180,7 +197,7 @@ class Database:
total_value=total_value,
realized_pnl=realized_pnl,
unrealized_pnl=unrealized_pnl,
- snapshot_at=datetime.now(timezone.utc),
+ snapshot_at=datetime.now(UTC),
)
session.add(row)
await session.commit()
@@ -191,7 +208,7 @@ class Database:
async def get_portfolio_snapshots(self, days: int = 30) -> list[dict]:
"""Retrieve recent portfolio snapshots."""
async with self.get_session() as session:
- since = datetime.now(timezone.utc) - timedelta(days=days)
+ since = datetime.now(UTC) - timedelta(days=days)
stmt = (
select(PortfolioSnapshotRow)
.where(PortfolioSnapshotRow.snapshot_at >= since)
@@ -234,7 +251,7 @@ class Database:
async def get_recent_news(self, hours: int = 24) -> list[dict]:
"""Retrieve news items published in the last N hours."""
- since = datetime.now(timezone.utc) - timedelta(hours=hours)
+ since = datetime.now(UTC) - timedelta(hours=hours)
stmt = (
select(NewsItemRow)
.where(NewsItemRow.published_at >= since)
@@ -352,7 +369,7 @@ class Database:
await session.rollback()
raise
- async def get_latest_market_sentiment(self) -> Optional[dict]:
+ async def get_latest_market_sentiment(self) -> dict | None:
"""Retrieve the 'latest' market sentiment row, or None if not found."""
stmt = select(MarketSentimentRow).where(MarketSentimentRow.id == "latest")
async with self._session_factory() as session:
@@ -394,7 +411,7 @@ class Database:
reason=reason,
key_news=json.dumps(key_news),
sentiment_snapshot=json.dumps(sentiment_snapshot),
- created_at=datetime.now(timezone.utc),
+ created_at=datetime.now(UTC),
)
async with self._session_factory() as session:
try:
diff --git a/shared/src/shared/events.py b/shared/src/shared/events.py
index 63f93a2..37217a0 100644
--- a/shared/src/shared/events.py
+++ b/shared/src/shared/events.py
@@ -1,14 +1,14 @@
"""Event types and serialization for the trading platform."""
-from enum import Enum
+from enum import StrEnum
from typing import Any
from pydantic import BaseModel
-from shared.models import Candle, Signal, Order, NewsItem
+from shared.models import Candle, NewsItem, Order, Signal
-class EventType(str, Enum):
+class EventType(StrEnum):
CANDLE = "CANDLE"
SIGNAL = "SIGNAL"
ORDER = "ORDER"
@@ -88,6 +88,16 @@ class Event:
@staticmethod
def from_dict(data: dict) -> Any:
- event_type = EventType(data["type"])
+ """Deserialize a raw dict into the appropriate event type.
+
+ Raises ValueError for malformed or unrecognized event data.
+ """
+ try:
+ event_type = EventType(data["type"])
+ except (KeyError, ValueError) as exc:
+ raise ValueError(f"Invalid or missing event type in data: {data!r}") from exc
cls = _EVENT_TYPE_MAP[event_type]
- return cls.from_raw(data)
+ try:
+ return cls.from_raw(data)
+ except KeyError as exc:
+ raise ValueError(f"Missing required field in {event_type} event data: {exc}") from exc
diff --git a/shared/src/shared/healthcheck.py b/shared/src/shared/healthcheck.py
index 7411e8a..a19705b 100644
--- a/shared/src/shared/healthcheck.py
+++ b/shared/src/shared/healthcheck.py
@@ -3,10 +3,11 @@
from __future__ import annotations
import time
-from typing import Any, Callable, Awaitable
+from collections.abc import Awaitable, Callable
+from typing import Any
from aiohttp import web
-from prometheus_client import CollectorRegistry, REGISTRY, generate_latest, CONTENT_TYPE_LATEST
+from prometheus_client import CONTENT_TYPE_LATEST, REGISTRY, CollectorRegistry, generate_latest
class HealthCheckServer:
diff --git a/shared/src/shared/metrics.py b/shared/src/shared/metrics.py
index cd239f3..6189143 100644
--- a/shared/src/shared/metrics.py
+++ b/shared/src/shared/metrics.py
@@ -2,7 +2,7 @@
from __future__ import annotations
-from prometheus_client import Counter, Gauge, Histogram, CollectorRegistry, REGISTRY
+from prometheus_client import REGISTRY, CollectorRegistry, Counter, Gauge, Histogram
class ServiceMetrics:
diff --git a/shared/src/shared/models.py b/shared/src/shared/models.py
index a436c03..f357f9f 100644
--- a/shared/src/shared/models.py
+++ b/shared/src/shared/models.py
@@ -1,25 +1,24 @@
"""Shared Pydantic models for the trading platform."""
import uuid
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
-from enum import Enum
-from typing import Optional
+from enum import StrEnum
from pydantic import BaseModel, Field, computed_field
-class OrderSide(str, Enum):
+class OrderSide(StrEnum):
BUY = "BUY"
SELL = "SELL"
-class OrderType(str, Enum):
+class OrderType(StrEnum):
MARKET = "MARKET"
LIMIT = "LIMIT"
-class OrderStatus(str, Enum):
+class OrderStatus(StrEnum):
PENDING = "PENDING"
FILLED = "FILLED"
CANCELLED = "CANCELLED"
@@ -46,9 +45,9 @@ class Signal(BaseModel):
quantity: Decimal
reason: str
conviction: float = 1.0 # 0.0 to 1.0, signal strength/confidence
- stop_loss: Optional[Decimal] = None # Price to exit at loss
- take_profit: Optional[Decimal] = None # Price to exit at profit
- created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ stop_loss: Decimal | None = None # Price to exit at loss
+ take_profit: Decimal | None = None # Price to exit at profit
+ created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
class Order(BaseModel):
@@ -60,8 +59,8 @@ class Order(BaseModel):
price: Decimal
quantity: Decimal
status: OrderStatus = OrderStatus.PENDING
- created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
- filled_at: Optional[datetime] = None
+ created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
+ filled_at: datetime | None = None
class Position(BaseModel):
@@ -76,7 +75,7 @@ class Position(BaseModel):
return self.quantity * (self.current_price - self.avg_entry_price)
-class NewsCategory(str, Enum):
+class NewsCategory(StrEnum):
POLICY = "policy"
EARNINGS = "earnings"
MACRO = "macro"
@@ -89,11 +88,11 @@ class NewsItem(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
source: str
headline: str
- summary: Optional[str] = None
- url: Optional[str] = None
+ summary: str | None = None
+ url: str | None = None
published_at: datetime
symbols: list[str] = []
sentiment: float
category: NewsCategory
raw_data: dict = {}
- created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
diff --git a/shared/src/shared/notifier.py b/shared/src/shared/notifier.py
index 3d7b6cf..cfc86cd 100644
--- a/shared/src/shared/notifier.py
+++ b/shared/src/shared/notifier.py
@@ -2,13 +2,13 @@
import asyncio
import logging
+from collections.abc import Sequence
from decimal import Decimal
-from typing import Optional, Sequence
import aiohttp
-from shared.models import Signal, Order, Position
-from shared.sentiment_models import SelectedStock, MarketSentiment
+from shared.models import Order, Position, Signal
+from shared.sentiment_models import MarketSentiment, SelectedStock
logger = logging.getLogger(__name__)
@@ -23,7 +23,7 @@ class TelegramNotifier:
self._bot_token = bot_token
self._chat_id = chat_id
self._semaphore = asyncio.Semaphore(1)
- self._session: Optional[aiohttp.ClientSession] = None
+ self._session: aiohttp.ClientSession | None = None
@property
def enabled(self) -> bool:
@@ -113,13 +113,13 @@ class TelegramNotifier:
"",
"<b>Positions:</b>",
]
- for pos in positions:
- lines.append(
- f" {pos.symbol}: qty={pos.quantity} "
- f"entry={pos.avg_entry_price} "
- f"current={pos.current_price} "
- f"pnl={pos.unrealized_pnl}"
- )
+ lines.extend(
+ f" {pos.symbol}: qty={pos.quantity} "
+ f"entry={pos.avg_entry_price} "
+ f"current={pos.current_price} "
+ f"pnl={pos.unrealized_pnl}"
+ for pos in positions
+ )
if not positions:
lines.append(" No open positions")
await self.send("\n".join(lines))
diff --git a/shared/src/shared/resilience.py b/shared/src/shared/resilience.py
index e43fd21..66225d7 100644
--- a/shared/src/shared/resilience.py
+++ b/shared/src/shared/resilience.py
@@ -1,29 +1,45 @@
-"""Retry with exponential backoff and circuit breaker utilities."""
+"""Resilience utilities for the trading platform.
+
+Provides retry, circuit breaker, and timeout primitives using only stdlib.
+No external dependencies required.
+"""
from __future__ import annotations
import asyncio
-import enum
import functools
import logging
import random
import time
-from typing import Any, Callable
+from collections.abc import Callable
+from contextlib import asynccontextmanager
+from enum import StrEnum
+from typing import Any
-logger = logging.getLogger(__name__)
+
+class _State(StrEnum):
+ CLOSED = "closed"
+ OPEN = "open"
+ HALF_OPEN = "half_open"
-# ---------------------------------------------------------------------------
-# retry_with_backoff
-# ---------------------------------------------------------------------------
+logger = logging.getLogger(__name__)
-def retry_with_backoff(
+def retry_async(
max_retries: int = 3,
base_delay: float = 1.0,
- max_delay: float = 60.0,
+ max_delay: float = 30.0,
+ exclude: tuple[type[BaseException], ...] = (),
) -> Callable:
- """Decorator that retries an async function with exponential backoff + jitter."""
+ """Decorator: exponential backoff + jitter for async functions.
+
+ Parameters:
+ max_retries: Maximum number of retry attempts (after the initial call).
+ base_delay: Base delay in seconds for exponential backoff.
+ max_delay: Maximum delay cap in seconds.
+ exclude: Exception types that should NOT be retried (raised immediately).
+ """
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
@@ -33,20 +49,21 @@ def retry_with_backoff(
try:
return await func(*args, **kwargs)
except Exception as exc:
+ if exclude and isinstance(exc, exclude):
+ raise
last_exc = exc
if attempt < max_retries:
delay = min(base_delay * (2**attempt), max_delay)
- jitter = delay * random.uniform(0, 0.5)
- total_delay = delay + jitter
+ jitter_delay = delay * random.uniform(0.5, 1.0)
logger.warning(
- "Retry %d/%d for %s after error: %s (delay=%.3fs)",
+ "Retry %d/%d for %s in %.2fs: %s",
attempt + 1,
max_retries,
func.__name__,
+ jitter_delay,
exc,
- total_delay,
)
- await asyncio.sleep(total_delay)
+ await asyncio.sleep(jitter_delay)
raise last_exc # type: ignore[misc]
return wrapper
@@ -54,52 +71,65 @@ def retry_with_backoff(
return decorator
-# ---------------------------------------------------------------------------
-# CircuitBreaker
-# ---------------------------------------------------------------------------
-
-
-class CircuitState(enum.Enum):
- CLOSED = "closed"
- OPEN = "open"
- HALF_OPEN = "half_open"
+class CircuitBreaker:
+ """Circuit breaker: opens after N consecutive failures, auto-recovers.
+ States: closed -> open -> half_open -> closed
-class CircuitBreaker:
- """Simple circuit breaker implementation."""
+ Parameters:
+ failure_threshold: Number of consecutive failures before opening.
+ cooldown: Seconds to wait before allowing a half-open probe.
+ """
- def __init__(
- self,
- failure_threshold: int = 5,
- recovery_timeout: float = 60.0,
- ) -> None:
+ def __init__(self, failure_threshold: int = 5, cooldown: float = 60.0) -> None:
self._failure_threshold = failure_threshold
- self._recovery_timeout = recovery_timeout
- self._failure_count: int = 0
- self._state = CircuitState.CLOSED
+ self._cooldown = cooldown
+ self._failures = 0
+ self._state = _State.CLOSED
self._opened_at: float = 0.0
- @property
- def state(self) -> CircuitState:
- return self._state
-
- def allow_request(self) -> bool:
- if self._state == CircuitState.CLOSED:
- return True
- if self._state == CircuitState.OPEN:
- if time.monotonic() - self._opened_at >= self._recovery_timeout:
- self._state = CircuitState.HALF_OPEN
- return True
- return False
- # HALF_OPEN
- return True
-
- def record_success(self) -> None:
- self._failure_count = 0
- self._state = CircuitState.CLOSED
-
- def record_failure(self) -> None:
- self._failure_count += 1
- if self._failure_count >= self._failure_threshold:
- self._state = CircuitState.OPEN
- self._opened_at = time.monotonic()
+ async def call(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
+ """Execute func through the breaker."""
+ if self._state == _State.OPEN:
+ if time.monotonic() - self._opened_at >= self._cooldown:
+ self._state = _State.HALF_OPEN
+ else:
+ raise RuntimeError("Circuit breaker is open")
+
+ try:
+ result = await func(*args, **kwargs)
+ except Exception:
+ self._failures += 1
+ if self._state == _State.HALF_OPEN:
+ self._state = _State.OPEN
+ self._opened_at = time.monotonic()
+ logger.error(
+ "Circuit breaker re-opened after half-open probe failure (threshold=%d)",
+ self._failure_threshold,
+ )
+ elif self._failures >= self._failure_threshold:
+ self._state = _State.OPEN
+ self._opened_at = time.monotonic()
+ logger.error(
+ "Circuit breaker opened after %d consecutive failures",
+ self._failures,
+ )
+ raise
+
+ # Success: reset
+ self._failures = 0
+ self._state = _State.CLOSED
+ return result
+
+
+@asynccontextmanager
+async def async_timeout(seconds: float):
+ """Async context manager wrapping asyncio.timeout().
+
+ Raises TimeoutError with a descriptive message on timeout.
+ """
+ try:
+ async with asyncio.timeout(seconds):
+ yield
+ except TimeoutError:
+ raise TimeoutError(f"Operation timed out after {seconds}s") from None
diff --git a/shared/src/shared/sa_models.py b/shared/src/shared/sa_models.py
index 1bd92c2..b70a6c4 100644
--- a/shared/src/shared/sa_models.py
+++ b/shared/src/shared/sa_models.py
@@ -35,6 +35,9 @@ class SignalRow(Base):
price: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
reason: Mapped[str | None] = mapped_column(Text)
+ conviction: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="1.0")
+ stop_loss: Mapped[Decimal | None] = mapped_column(Numeric)
+ take_profit: Mapped[Decimal | None] = mapped_column(Numeric)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
diff --git a/shared/src/shared/sentiment.py b/shared/src/shared/sentiment.py
index 449eb76..c56da3e 100644
--- a/shared/src/shared/sentiment.py
+++ b/shared/src/shared/sentiment.py
@@ -1,41 +1,10 @@
-"""Market sentiment data."""
+"""Market sentiment aggregation."""
-import logging
-from dataclasses import dataclass, field
-from datetime import datetime, timezone
+from datetime import datetime
+from typing import ClassVar
from shared.sentiment_models import SymbolScore
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class SentimentData:
- """Aggregated sentiment snapshot."""
-
- fear_greed_value: int | None = None
- fear_greed_label: str | None = None
- news_sentiment: float | None = None
- news_count: int = 0
- exchange_netflow: float | None = None
- timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
-
- @property
- def should_buy(self) -> bool:
- if self.fear_greed_value is not None and self.fear_greed_value > 70:
- return False
- if self.news_sentiment is not None and self.news_sentiment < -0.3:
- return False
- return True
-
- @property
- def should_block(self) -> bool:
- if self.fear_greed_value is not None and self.fear_greed_value > 80:
- return True
- if self.news_sentiment is not None and self.news_sentiment < -0.5:
- return True
- return False
-
def _safe_avg(values: list[float]) -> float:
if not values:
@@ -46,9 +15,9 @@ def _safe_avg(values: list[float]) -> float:
class SentimentAggregator:
"""Aggregates per-news sentiment into per-symbol scores."""
- WEIGHTS = {"news": 0.3, "social": 0.2, "policy": 0.3, "filing": 0.2}
+ WEIGHTS: ClassVar[dict[str, float]] = {"news": 0.3, "social": 0.2, "policy": 0.3, "filing": 0.2}
- CATEGORY_MAP = {
+ CATEGORY_MAP: ClassVar[dict[str, str]] = {
"earnings": "news",
"macro": "news",
"social": "social",
diff --git a/shared/src/shared/sentiment_models.py b/shared/src/shared/sentiment_models.py
index a009601..ac06c20 100644
--- a/shared/src/shared/sentiment_models.py
+++ b/shared/src/shared/sentiment_models.py
@@ -1,7 +1,6 @@
"""Sentiment scoring and stock selection models."""
from datetime import datetime
-from typing import Optional
from pydantic import BaseModel
@@ -22,7 +21,7 @@ class SymbolScore(BaseModel):
class MarketSentiment(BaseModel):
fear_greed: int
fear_greed_label: str
- vix: Optional[float] = None
+ vix: float | None = None
fed_stance: str
market_regime: str
updated_at: datetime
@@ -39,6 +38,6 @@ class SelectedStock(BaseModel):
class Candidate(BaseModel):
symbol: str
source: str
- direction: Optional[OrderSide] = None
+ direction: OrderSide | None = None
score: float
reason: str
diff --git a/shared/src/shared/shutdown.py b/shared/src/shared/shutdown.py
new file mode 100644
index 0000000..4ed9aa7
--- /dev/null
+++ b/shared/src/shared/shutdown.py
@@ -0,0 +1,30 @@
+"""Graceful shutdown utilities for services."""
+
+import asyncio
+import logging
+import signal
+
+logger = logging.getLogger(__name__)
+
+
+class GracefulShutdown:
+ """Manages graceful shutdown via SIGTERM/SIGINT signals."""
+
+ def __init__(self) -> None:
+ self._event = asyncio.Event()
+
+ @property
+ def is_shutting_down(self) -> bool:
+ return self._event.is_set()
+
+ async def wait(self) -> None:
+ await self._event.wait()
+
+ def trigger(self) -> None:
+ logger.info("shutdown_signal_received")
+ self._event.set()
+
+ def install_handlers(self) -> None:
+ loop = asyncio.get_running_loop()
+ for sig in (signal.SIGTERM, signal.SIGINT):
+ loop.add_signal_handler(sig, self.trigger)
diff --git a/shared/tests/test_alpaca.py b/shared/tests/test_alpaca.py
index 080b7c4..55a2b24 100644
--- a/shared/tests/test_alpaca.py
+++ b/shared/tests/test_alpaca.py
@@ -1,7 +1,9 @@
"""Tests for Alpaca API client."""
-import pytest
from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
from shared.alpaca import AlpacaClient
diff --git a/shared/tests/test_broker.py b/shared/tests/test_broker.py
index eb1582d..5636611 100644
--- a/shared/tests/test_broker.py
+++ b/shared/tests/test_broker.py
@@ -1,10 +1,11 @@
"""Tests for the Redis broker."""
-import pytest
import json
-import redis
from unittest.mock import AsyncMock, patch
+import pytest
+import redis
+
@pytest.mark.asyncio
async def test_broker_publish():
diff --git a/shared/tests/test_config_validation.py b/shared/tests/test_config_validation.py
new file mode 100644
index 0000000..9376dc6
--- /dev/null
+++ b/shared/tests/test_config_validation.py
@@ -0,0 +1,29 @@
+"""Tests for config validation."""
+
+import pytest
+from pydantic import ValidationError
+
+from shared.config import Settings
+
+
+class TestConfigValidation:
+ def test_valid_defaults(self):
+ settings = Settings()
+ assert settings.risk_max_position_size == 0.1
+
+ def test_invalid_position_size(self):
+ with pytest.raises(ValidationError, match="risk_max_position_size"):
+ Settings(risk_max_position_size=-0.1)
+
+ def test_invalid_health_port(self):
+ with pytest.raises(ValidationError, match="health_port"):
+ Settings(health_port=80)
+
+ def test_invalid_log_level(self):
+ with pytest.raises(ValidationError, match="log_level"):
+ Settings(log_level="INVALID")
+
+ def test_secret_fields_masked(self):
+ settings = Settings(alpaca_api_key="my-secret-key")
+ assert "my-secret-key" not in repr(settings)
+ assert settings.alpaca_api_key.get_secret_value() == "my-secret-key"
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
index 239ee64..b44a713 100644
--- a/shared/tests/test_db.py
+++ b/shared/tests/test_db.py
@@ -1,10 +1,11 @@
"""Tests for the SQLAlchemy async database layer."""
-import pytest
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
+import pytest
+
def make_candle():
from shared.models import Candle
@@ -12,7 +13,7 @@ def make_candle():
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("51000"),
low=Decimal("49500"),
@@ -22,7 +23,7 @@ def make_candle():
def make_signal():
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
return Signal(
id="sig-1",
@@ -32,12 +33,12 @@ def make_signal():
price=Decimal("50000"),
quantity=Decimal("0.1"),
reason="Golden cross",
- created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
def make_order():
- from shared.models import Order, OrderSide, OrderType, OrderStatus
+ from shared.models import Order, OrderSide, OrderStatus, OrderType
return Order(
id="ord-1",
@@ -48,7 +49,7 @@ def make_order():
price=Decimal("50000"),
quantity=Decimal("0.1"),
status=OrderStatus.PENDING,
- created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
@@ -101,6 +102,54 @@ class TestDatabaseConnect:
mock_create.assert_called_once()
@pytest.mark.asyncio
+ async def test_connect_passes_pool_params_for_postgres(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_conn = AsyncMock()
+ mock_cm = AsyncMock()
+ mock_cm.__aenter__.return_value = mock_conn
+
+ mock_engine = MagicMock()
+ mock_engine.begin.return_value = mock_cm
+ mock_engine.dispose = AsyncMock()
+
+ with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create:
+ with patch("shared.db.async_sessionmaker"):
+ with patch("shared.db.Base") as mock_base:
+ mock_base.metadata.create_all = MagicMock()
+ await db.connect(pool_size=5, max_overflow=3, pool_recycle=1800)
+ mock_create.assert_called_once_with(
+ "postgresql+asyncpg://host/db",
+ pool_pre_ping=True,
+ pool_size=5,
+ max_overflow=3,
+ pool_recycle=1800,
+ )
+
+ @pytest.mark.asyncio
+ async def test_connect_skips_pool_params_for_sqlite(self):
+ from shared.db import Database
+
+ db = Database("sqlite+aiosqlite:///test.db")
+
+ mock_conn = AsyncMock()
+ mock_cm = AsyncMock()
+ mock_cm.__aenter__.return_value = mock_conn
+
+ mock_engine = MagicMock()
+ mock_engine.begin.return_value = mock_cm
+ mock_engine.dispose = AsyncMock()
+
+ with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create:
+ with patch("shared.db.async_sessionmaker"):
+ with patch("shared.db.Base") as mock_base:
+ mock_base.metadata.create_all = MagicMock()
+ await db.connect()
+ mock_create.assert_called_once_with("sqlite+aiosqlite:///test.db")
+
+ @pytest.mark.asyncio
async def test_init_tables_is_alias_for_connect(self):
from shared.db import Database
@@ -211,7 +260,7 @@ class TestUpdateOrderStatus:
db._session_factory = MagicMock(return_value=mock_session)
- filled = datetime(2024, 1, 2, tzinfo=timezone.utc)
+ filled = datetime(2024, 1, 2, tzinfo=UTC)
await db.update_order_status("ord-1", OrderStatus.FILLED, filled)
mock_session.execute.assert_awaited_once()
@@ -230,7 +279,7 @@ class TestGetCandles:
mock_row._mapping = {
"symbol": "AAPL",
"timeframe": "1m",
- "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc),
+ "open_time": datetime(2024, 1, 1, tzinfo=UTC),
"open": Decimal("50000"),
"high": Decimal("51000"),
"low": Decimal("49500"),
@@ -396,7 +445,7 @@ class TestGetPortfolioSnapshots:
mock_row.total_value = Decimal("10000")
mock_row.realized_pnl = Decimal("0")
mock_row.unrealized_pnl = Decimal("500")
- mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
+ mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=UTC)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [mock_row]
diff --git a/shared/tests/test_db_news.py b/shared/tests/test_db_news.py
index a2c9140..c184bed 100644
--- a/shared/tests/test_db_news.py
+++ b/shared/tests/test_db_news.py
@@ -1,11 +1,12 @@
"""Tests for database news/sentiment methods. Uses in-memory SQLite."""
+from datetime import UTC, date, datetime
+
import pytest
-from datetime import datetime, date, timezone
from shared.db import Database
-from shared.models import NewsItem, NewsCategory
-from shared.sentiment_models import SymbolScore, MarketSentiment
+from shared.models import NewsCategory, NewsItem
+from shared.sentiment_models import MarketSentiment, SymbolScore
@pytest.fixture
@@ -20,7 +21,7 @@ async def test_insert_and_get_news_items(db):
item = NewsItem(
source="finnhub",
headline="AAPL earnings beat",
- published_at=datetime(2026, 4, 2, 12, 0, tzinfo=timezone.utc),
+ published_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC),
sentiment=0.8,
category=NewsCategory.EARNINGS,
symbols=["AAPL"],
@@ -40,7 +41,7 @@ async def test_upsert_symbol_score(db):
policy_score=0.0,
filing_score=0.2,
composite=0.3,
- updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
)
await db.upsert_symbol_score(score)
scores = await db.get_top_symbol_scores(limit=5)
@@ -55,7 +56,7 @@ async def test_upsert_market_sentiment(db):
vix=18.2,
fed_stance="neutral",
market_regime="neutral",
- updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
)
await db.upsert_market_sentiment(ms)
result = await db.get_latest_market_sentiment()
diff --git a/shared/tests/test_events.py b/shared/tests/test_events.py
index 6077d93..1ccd904 100644
--- a/shared/tests/test_events.py
+++ b/shared/tests/test_events.py
@@ -1,7 +1,7 @@
"""Tests for shared event types."""
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
def make_candle():
@@ -10,7 +10,7 @@ def make_candle():
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("51000"),
low=Decimal("49500"),
@@ -20,7 +20,7 @@ def make_candle():
def make_signal():
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
return Signal(
strategy="test",
@@ -59,7 +59,7 @@ def test_candle_event_deserialize():
def test_signal_event_serialize():
"""Test SignalEvent serializes to dict correctly."""
- from shared.events import SignalEvent, EventType
+ from shared.events import EventType, SignalEvent
signal = make_signal()
event = SignalEvent(data=signal)
@@ -71,7 +71,7 @@ def test_signal_event_serialize():
def test_event_from_dict_dispatch():
"""Test Event.from_dict dispatches to correct class."""
- from shared.events import Event, CandleEvent, SignalEvent
+ from shared.events import CandleEvent, Event, SignalEvent
candle = make_candle()
event = CandleEvent(data=candle)
diff --git a/shared/tests/test_models.py b/shared/tests/test_models.py
index 04098ce..40bb791 100644
--- a/shared/tests/test_models.py
+++ b/shared/tests/test_models.py
@@ -1,8 +1,8 @@
"""Tests for shared models and settings."""
import os
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
from unittest.mock import patch
@@ -12,8 +12,11 @@ def test_settings_defaults():
with patch.dict(os.environ, {}, clear=False):
settings = Settings()
- assert settings.redis_url == "redis://localhost:6379"
- assert settings.database_url == "postgresql://trading:trading@localhost:5432/trading"
+ assert settings.redis_url.get_secret_value() == "redis://localhost:6379"
+ assert (
+ settings.database_url.get_secret_value()
+ == "postgresql://trading:trading@localhost:5432/trading"
+ )
assert settings.log_level == "INFO"
assert settings.risk_max_position_size == 0.1
assert settings.risk_stop_loss_pct == 5.0
@@ -25,7 +28,7 @@ def test_candle_creation():
"""Test Candle model creation."""
from shared.models import Candle
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
candle = Candle(
symbol="AAPL",
timeframe="1m",
@@ -47,7 +50,7 @@ def test_candle_creation():
def test_signal_creation():
"""Test Signal model creation."""
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
signal = Signal(
strategy="rsi_strategy",
@@ -69,9 +72,10 @@ def test_signal_creation():
def test_order_creation():
"""Test Order model creation with defaults."""
- from shared.models import Order, OrderSide, OrderType, OrderStatus
import uuid
+ from shared.models import Order, OrderSide, OrderStatus, OrderType
+
signal_id = str(uuid.uuid4())
order = Order(
signal_id=signal_id,
@@ -90,7 +94,7 @@ def test_order_creation():
def test_signal_conviction_default():
"""Test Signal defaults for conviction, stop_loss, take_profit."""
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
signal = Signal(
strategy="rsi",
@@ -107,7 +111,7 @@ def test_signal_conviction_default():
def test_signal_with_stops():
"""Test Signal with explicit conviction, stop_loss, take_profit."""
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
signal = Signal(
strategy="rsi",
diff --git a/shared/tests/test_news_events.py b/shared/tests/test_news_events.py
index 384796a..f748d8a 100644
--- a/shared/tests/test_news_events.py
+++ b/shared/tests/test_news_events.py
@@ -1,16 +1,16 @@
"""Tests for NewsEvent."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
+from shared.events import Event, EventType, NewsEvent
from shared.models import NewsCategory, NewsItem
-from shared.events import NewsEvent, EventType, Event
def test_news_event_to_dict():
item = NewsItem(
source="finnhub",
headline="Test",
- published_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ published_at=datetime(2026, 4, 2, tzinfo=UTC),
sentiment=0.5,
category=NewsCategory.MACRO,
)
diff --git a/shared/tests/test_notifier.py b/shared/tests/test_notifier.py
index 6c81369..cc98a56 100644
--- a/shared/tests/test_notifier.py
+++ b/shared/tests/test_notifier.py
@@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
-from shared.models import Signal, Order, OrderSide, OrderType, OrderStatus, Position
+from shared.models import Order, OrderSide, OrderStatus, OrderType, Position, Signal
from shared.notifier import TelegramNotifier
diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py
index e287777..e0781af 100644
--- a/shared/tests/test_resilience.py
+++ b/shared/tests/test_resilience.py
@@ -1,139 +1,176 @@
-"""Tests for retry with backoff and circuit breaker."""
+"""Tests for shared.resilience module."""
-import time
+import asyncio
import pytest
-from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff
+from shared.resilience import CircuitBreaker, async_timeout, retry_async
+# --- retry_async tests ---
-# ---------------------------------------------------------------------------
-# retry_with_backoff tests
-# ---------------------------------------------------------------------------
-
-@pytest.mark.asyncio
-async def test_retry_succeeds_first_try():
+async def test_succeeds_without_retry():
+ """Function succeeds first try, called once."""
call_count = 0
- @retry_with_backoff(max_retries=3, base_delay=0.01)
- async def succeed():
+ @retry_async()
+ async def fn():
nonlocal call_count
call_count += 1
return "ok"
- result = await succeed()
+ result = await fn()
assert result == "ok"
assert call_count == 1
-@pytest.mark.asyncio
-async def test_retry_succeeds_after_failures():
+async def test_retries_on_failure_then_succeeds():
+ """Fails twice then succeeds, verify call count."""
call_count = 0
- @retry_with_backoff(max_retries=3, base_delay=0.01)
- async def flaky():
+ @retry_async(max_retries=3, base_delay=0.01)
+ async def fn():
nonlocal call_count
call_count += 1
if call_count < 3:
- raise ValueError("not yet")
+ raise RuntimeError("transient")
return "recovered"
- result = await flaky()
+ result = await fn()
assert result == "recovered"
assert call_count == 3
-@pytest.mark.asyncio
-async def test_retry_raises_after_max_retries():
+async def test_raises_after_max_retries():
+ """Always fails, raises after max retries."""
call_count = 0
- @retry_with_backoff(max_retries=3, base_delay=0.01)
- async def always_fail():
+ @retry_async(max_retries=3, base_delay=0.01)
+ async def fn():
nonlocal call_count
call_count += 1
- raise RuntimeError("permanent")
+ raise ValueError("permanent")
- with pytest.raises(RuntimeError, match="permanent"):
- await always_fail()
- # 1 initial + 3 retries = 4 calls
+ with pytest.raises(ValueError, match="permanent"):
+ await fn()
+
+ # 1 initial + 3 retries = 4 total calls
assert call_count == 4
-@pytest.mark.asyncio
-async def test_retry_respects_max_delay():
- """Backoff should be capped at max_delay."""
+async def test_no_retry_on_excluded_exception():
+ """Excluded exception raises immediately, call count = 1."""
+ call_count = 0
- @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02)
- async def always_fail():
- raise RuntimeError("fail")
+ @retry_async(max_retries=3, base_delay=0.01, exclude=(TypeError,))
+ async def fn():
+ nonlocal call_count
+ call_count += 1
+ raise TypeError("excluded")
- start = time.monotonic()
- with pytest.raises(RuntimeError):
- await always_fail()
- elapsed = time.monotonic() - start
- # With max_delay=0.02 and 2 retries, total delay should be small
- assert elapsed < 0.5
+ with pytest.raises(TypeError, match="excluded"):
+ await fn()
+
+ assert call_count == 1
-# ---------------------------------------------------------------------------
-# CircuitBreaker tests
-# ---------------------------------------------------------------------------
+# --- CircuitBreaker tests ---
-def test_circuit_starts_closed():
- cb = CircuitBreaker(failure_threshold=3, recovery_timeout=0.05)
- assert cb.state == CircuitState.CLOSED
- assert cb.allow_request() is True
+async def test_closed_allows_calls():
+ """CircuitBreaker in closed state passes through."""
+ cb = CircuitBreaker(failure_threshold=5, cooldown=60.0)
+ async def fn():
+ return "ok"
+
+ result = await cb.call(fn)
+ assert result == "ok"
+
+
+async def test_opens_after_threshold():
+ """After N failures, raises RuntimeError."""
+ cb = CircuitBreaker(failure_threshold=3, cooldown=60.0)
+
+ async def fail():
+ raise RuntimeError("fail")
-def test_circuit_opens_after_threshold():
- cb = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0)
for _ in range(3):
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
- assert cb.allow_request() is False
+ with pytest.raises(RuntimeError, match="fail"):
+ await cb.call(fail)
+ # Now the breaker should be open
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(fail)
+
+
+async def test_half_open_after_cooldown():
+ """After cooldown, allows recovery attempt."""
+ cb = CircuitBreaker(failure_threshold=2, cooldown=0.05)
+
+ async def fail():
+ raise RuntimeError("fail")
+
+ # Trip the breaker
+ for _ in range(2):
+ with pytest.raises(RuntimeError, match="fail"):
+ await cb.call(fail)
+
+ # Breaker is open
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(fail)
+
+ # Wait for cooldown
+ await asyncio.sleep(0.06)
+
+ # Now should allow a call (half_open). Succeed to close it.
+ async def succeed():
+ return "recovered"
+
+ result = await cb.call(succeed)
+ assert result == "recovered"
+
+ # Breaker should be closed again
+ result = await cb.call(succeed)
+ assert result == "recovered"
+
+
+async def test_half_open_reopens_on_failure():
+ cb = CircuitBreaker(failure_threshold=2, cooldown=0.05)
+
+ async def always_fail():
+ raise ConnectionError("fail")
-def test_circuit_rejects_when_open():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0)
- cb.record_failure()
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
- assert cb.allow_request() is False
+ # Trip the breaker
+ for _ in range(2):
+ with pytest.raises(ConnectionError):
+ await cb.call(always_fail)
+ # Wait for cooldown
+ await asyncio.sleep(0.1)
-def test_circuit_half_open_after_timeout():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
- cb.record_failure()
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
+ # Half-open probe should fail and re-open
+ with pytest.raises(ConnectionError):
+ await cb.call(always_fail)
- time.sleep(0.06)
- assert cb.allow_request() is True
- assert cb.state == CircuitState.HALF_OPEN
+ # Should be open again (no cooldown wait)
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(always_fail)
-def test_circuit_closes_on_success():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
- cb.record_failure()
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
+# --- async_timeout tests ---
- time.sleep(0.06)
- cb.allow_request() # triggers HALF_OPEN
- assert cb.state == CircuitState.HALF_OPEN
- cb.record_success()
- assert cb.state == CircuitState.CLOSED
- assert cb.allow_request() is True
+async def test_completes_within_timeout():
+ """async_timeout doesn't interfere with fast operations."""
+ async with async_timeout(1.0):
+ await asyncio.sleep(0.01)
+ result = 42
+ assert result == 42
-def test_circuit_reopens_on_failure_in_half_open():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
- cb.record_failure()
- cb.record_failure()
- time.sleep(0.06)
- cb.allow_request() # HALF_OPEN
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
+async def test_raises_on_timeout():
+ """async_timeout raises TimeoutError for slow operations."""
+ with pytest.raises(TimeoutError):
+ async with async_timeout(0.05):
+ await asyncio.sleep(1.0)
diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py
index dc6355e..c9311dd 100644
--- a/shared/tests/test_sa_models.py
+++ b/shared/tests/test_sa_models.py
@@ -72,6 +72,9 @@ class TestSignalRow:
"price",
"quantity",
"reason",
+ "conviction",
+ "stop_loss",
+ "take_profit",
"created_at",
}
assert expected == cols
@@ -124,44 +127,6 @@ class TestOrderRow:
assert fk_cols == {"signal_id": "signals.id"}
-class TestTradeRow:
- def test_table_name(self):
- from shared.sa_models import TradeRow
-
- assert TradeRow.__tablename__ == "trades"
-
- def test_columns(self):
- from shared.sa_models import TradeRow
-
- mapper = inspect(TradeRow)
- cols = {c.key for c in mapper.column_attrs}
- expected = {
- "id",
- "order_id",
- "symbol",
- "side",
- "price",
- "quantity",
- "fee",
- "traded_at",
- }
- assert expected == cols
-
- def test_primary_key(self):
- from shared.sa_models import TradeRow
-
- mapper = inspect(TradeRow)
- pk_cols = [c.name for c in mapper.mapper.primary_key]
- assert pk_cols == ["id"]
-
- def test_order_id_foreign_key(self):
- from shared.sa_models import TradeRow
-
- table = TradeRow.__table__
- fk_cols = {fk.parent.name: fk.target_fullname for fk in table.foreign_keys}
- assert fk_cols == {"order_id": "orders.id"}
-
-
class TestPositionRow:
def test_table_name(self):
from shared.sa_models import PositionRow
@@ -233,11 +198,3 @@ class TestStatusDefault:
status_col = table.c.status
assert status_col.server_default is not None
assert status_col.server_default.arg == "PENDING"
-
- def test_trade_fee_server_default(self):
- from shared.sa_models import TradeRow
-
- table = TradeRow.__table__
- fee_col = table.c.fee
- assert fee_col.server_default is not None
- assert fee_col.server_default.arg == "0"
diff --git a/shared/tests/test_sa_news_models.py b/shared/tests/test_sa_news_models.py
index 91e6d4a..dc2d026 100644
--- a/shared/tests/test_sa_news_models.py
+++ b/shared/tests/test_sa_news_models.py
@@ -1,6 +1,6 @@
"""Tests for news-related SQLAlchemy models."""
-from shared.sa_models import NewsItemRow, SymbolScoreRow, MarketSentimentRow, StockSelectionRow
+from shared.sa_models import MarketSentimentRow, NewsItemRow, StockSelectionRow, SymbolScoreRow
def test_news_item_row_tablename():
diff --git a/shared/tests/test_sentiment.py b/shared/tests/test_sentiment.py
deleted file mode 100644
index 9bd8ea3..0000000
--- a/shared/tests/test_sentiment.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""Tests for market sentiment module."""
-
-from shared.sentiment import SentimentData
-
-
-def test_sentiment_should_buy_default_no_data():
- s = SentimentData()
- assert s.should_buy is True
- assert s.should_block is False
-
-
-def test_sentiment_should_buy_low_fear_greed():
- s = SentimentData(fear_greed_value=15)
- assert s.should_buy is True
-
-
-def test_sentiment_should_not_buy_on_greed():
- s = SentimentData(fear_greed_value=75)
- assert s.should_buy is False
-
-
-def test_sentiment_should_not_buy_negative_news():
- s = SentimentData(news_sentiment=-0.4)
- assert s.should_buy is False
-
-
-def test_sentiment_should_buy_positive_news():
- s = SentimentData(fear_greed_value=50, news_sentiment=0.3)
- assert s.should_buy is True
-
-
-def test_sentiment_should_block_extreme_greed():
- s = SentimentData(fear_greed_value=85)
- assert s.should_block is True
-
-
-def test_sentiment_should_block_very_negative_news():
- s = SentimentData(news_sentiment=-0.6)
- assert s.should_block is True
-
-
-def test_sentiment_no_block_on_neutral():
- s = SentimentData(fear_greed_value=50, news_sentiment=0.0)
- assert s.should_block is False
diff --git a/shared/tests/test_sentiment_aggregator.py b/shared/tests/test_sentiment_aggregator.py
index a99c711..9193785 100644
--- a/shared/tests/test_sentiment_aggregator.py
+++ b/shared/tests/test_sentiment_aggregator.py
@@ -1,7 +1,9 @@
"""Tests for sentiment aggregator."""
+from datetime import UTC, datetime, timedelta
+
import pytest
-from datetime import datetime, timezone, timedelta
+
from shared.sentiment import SentimentAggregator
@@ -12,25 +14,25 @@ def aggregator():
def test_freshness_decay_recent():
a = SentimentAggregator()
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
assert a._freshness_decay(now, now) == 1.0
def test_freshness_decay_3_hours():
a = SentimentAggregator()
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
assert a._freshness_decay(now - timedelta(hours=3), now) == 0.7
def test_freshness_decay_12_hours():
a = SentimentAggregator()
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
assert a._freshness_decay(now - timedelta(hours=12), now) == 0.3
def test_freshness_decay_old():
a = SentimentAggregator()
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
assert a._freshness_decay(now - timedelta(days=2), now) == 0.0
@@ -44,7 +46,7 @@ def test_compute_composite():
def test_aggregate_news_by_symbol(aggregator):
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
news_items = [
{"symbols": ["AAPL"], "sentiment": 0.8, "category": "earnings", "published_at": now},
{
@@ -64,7 +66,7 @@ def test_aggregate_news_by_symbol(aggregator):
def test_aggregate_empty(aggregator):
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
assert aggregator.aggregate([], now) == {}
diff --git a/shared/tests/test_sentiment_models.py b/shared/tests/test_sentiment_models.py
index 25fc371..e00ffa6 100644
--- a/shared/tests/test_sentiment_models.py
+++ b/shared/tests/test_sentiment_models.py
@@ -1,16 +1,16 @@
"""Tests for news and sentiment models."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from shared.models import NewsCategory, NewsItem, OrderSide
-from shared.sentiment_models import SymbolScore, MarketSentiment, SelectedStock, Candidate
+from shared.sentiment_models import Candidate, MarketSentiment, SelectedStock, SymbolScore
def test_news_item_defaults():
item = NewsItem(
source="finnhub",
headline="Test headline",
- published_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ published_at=datetime(2026, 4, 2, tzinfo=UTC),
sentiment=0.5,
category=NewsCategory.MACRO,
)
@@ -25,7 +25,7 @@ def test_news_item_with_symbols():
item = NewsItem(
source="rss",
headline="AAPL earnings beat",
- published_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ published_at=datetime(2026, 4, 2, tzinfo=UTC),
sentiment=0.8,
category=NewsCategory.EARNINGS,
symbols=["AAPL"],
@@ -52,7 +52,7 @@ def test_symbol_score():
policy_score=0.0,
filing_score=0.2,
composite=0.3,
- updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
)
assert score.symbol == "AAPL"
assert score.composite == 0.3
@@ -65,7 +65,7 @@ def test_market_sentiment():
vix=32.5,
fed_stance="hawkish",
market_regime="risk_off",
- updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
)
assert ms.market_regime == "risk_off"
assert ms.vix == 32.5
@@ -77,7 +77,7 @@ def test_market_sentiment_no_vix():
fear_greed_label="Neutral",
fed_stance="neutral",
market_regime="neutral",
- updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
)
assert ms.vix is None
diff --git a/tests/edge_cases/test_empty_data.py b/tests/edge_cases/test_empty_data.py
index bfefc95..876a640 100644
--- a/tests/edge_cases/test_empty_data.py
+++ b/tests/edge_cases/test_empty_data.py
@@ -12,13 +12,14 @@ sys.path.insert(
0, str(Path(__file__).resolve().parents[2] / "services" / "portfolio-manager" / "src")
)
-from shared.models import Signal, OrderSide
from backtester.engine import BacktestEngine
from backtester.metrics import compute_detailed_metrics
-from portfolio_manager.portfolio import PortfolioTracker
from order_executor.risk_manager import RiskManager
+from portfolio_manager.portfolio import PortfolioTracker
from strategies.rsi_strategy import RsiStrategy
+from shared.models import OrderSide, Signal
+
class TestBacktestEngineEmptyCandles:
"""BacktestEngine.run([]) should return valid result with 0 trades."""
diff --git a/tests/edge_cases/test_extreme_values.py b/tests/edge_cases/test_extreme_values.py
index b375d5e..8ec3b77 100644
--- a/tests/edge_cases/test_extreme_values.py
+++ b/tests/edge_cases/test_extreme_values.py
@@ -1,7 +1,7 @@
"""Tests for extreme value edge cases."""
import sys
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
from pathlib import Path
@@ -9,19 +9,20 @@ sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "strat
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "backtester" / "src"))
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "order-executor" / "src"))
-from shared.models import Candle, Signal, OrderSide
-from strategies.rsi_strategy import RsiStrategy
-from strategies.vwap_strategy import VwapStrategy
-from strategies.bollinger_strategy import BollingerStrategy
from backtester.engine import BacktestEngine
from backtester.simulator import OrderSimulator
from order_executor.risk_manager import RiskManager
+from strategies.bollinger_strategy import BollingerStrategy
+from strategies.rsi_strategy import RsiStrategy
+from strategies.vwap_strategy import VwapStrategy
+
+from shared.models import Candle, OrderSide, Signal
def _candle(close: str, volume: str = "1000", idx: int = 0) -> Candle:
from datetime import timedelta
- base = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ base = datetime(2025, 1, 1, tzinfo=UTC)
return Candle(
symbol="AAPL",
timeframe="1h",
diff --git a/tests/edge_cases/test_strategy_reset.py b/tests/edge_cases/test_strategy_reset.py
index 6e9b956..13ed4da 100644
--- a/tests/edge_cases/test_strategy_reset.py
+++ b/tests/edge_cases/test_strategy_reset.py
@@ -1,21 +1,22 @@
"""Tests that strategy reset() properly clears internal state."""
import sys
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "strategy-engine"))
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "backtester" / "src"))
-from shared.models import Candle
-from strategies.rsi_strategy import RsiStrategy
-from strategies.grid_strategy import GridStrategy
-from strategies.macd_strategy import MacdStrategy
from strategies.bollinger_strategy import BollingerStrategy
from strategies.ema_crossover_strategy import EmaCrossoverStrategy
-from strategies.vwap_strategy import VwapStrategy
+from strategies.grid_strategy import GridStrategy
+from strategies.macd_strategy import MacdStrategy
+from strategies.rsi_strategy import RsiStrategy
from strategies.volume_profile_strategy import VolumeProfileStrategy
+from strategies.vwap_strategy import VwapStrategy
+
+from shared.models import Candle
def _make_candles(count: int, base_price: float = 100.0) -> list[Candle]:
@@ -28,7 +29,7 @@ def _make_candles(count: int, base_price: float = 100.0) -> list[Candle]:
Candle(
symbol="AAPL",
timeframe="1h",
- open_time=datetime(2025, 1, 1, i % 24, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, i % 24, tzinfo=UTC),
open=Decimal(str(price)),
high=Decimal(str(price + 1)),
low=Decimal(str(price - 1)),
@@ -57,7 +58,7 @@ class TestRsiReset:
strategy.reset()
signals2 = _collect_signals(strategy, candles)
assert len(signals1) == len(signals2)
- for s1, s2 in zip(signals1, signals2):
+ for s1, s2 in zip(signals1, signals2, strict=True):
assert s1.side == s2.side
assert s1.price == s2.price
@@ -71,7 +72,7 @@ class TestGridReset:
strategy.reset()
signals2 = _collect_signals(strategy, candles)
assert len(signals1) == len(signals2)
- for s1, s2 in zip(signals1, signals2):
+ for s1, s2 in zip(signals1, signals2, strict=True):
assert s1.side == s2.side
assert s1.price == s2.price
@@ -84,7 +85,7 @@ class TestMacdReset:
strategy.reset()
signals2 = _collect_signals(strategy, candles)
assert len(signals1) == len(signals2)
- for s1, s2 in zip(signals1, signals2):
+ for s1, s2 in zip(signals1, signals2, strict=True):
assert s1.side == s2.side
assert s1.price == s2.price
@@ -97,7 +98,7 @@ class TestBollingerReset:
strategy.reset()
signals2 = _collect_signals(strategy, candles)
assert len(signals1) == len(signals2)
- for s1, s2 in zip(signals1, signals2):
+ for s1, s2 in zip(signals1, signals2, strict=True):
assert s1.side == s2.side
assert s1.price == s2.price
@@ -110,7 +111,7 @@ class TestEmaCrossoverReset:
strategy.reset()
signals2 = _collect_signals(strategy, candles)
assert len(signals1) == len(signals2)
- for s1, s2 in zip(signals1, signals2):
+ for s1, s2 in zip(signals1, signals2, strict=True):
assert s1.side == s2.side
assert s1.price == s2.price
@@ -123,7 +124,7 @@ class TestVwapReset:
strategy.reset()
signals2 = _collect_signals(strategy, candles)
assert len(signals1) == len(signals2)
- for s1, s2 in zip(signals1, signals2):
+ for s1, s2 in zip(signals1, signals2, strict=True):
assert s1.side == s2.side
assert s1.price == s2.price
@@ -136,6 +137,6 @@ class TestVolumeProfileReset:
strategy.reset()
signals2 = _collect_signals(strategy, candles)
assert len(signals1) == len(signals2)
- for s1, s2 in zip(signals1, signals2):
+ for s1, s2 in zip(signals1, signals2, strict=True):
assert s1.side == s2.side
assert s1.price == s2.price
diff --git a/tests/edge_cases/test_zero_volume.py b/tests/edge_cases/test_zero_volume.py
index ba2c133..df247cc 100644
--- a/tests/edge_cases/test_zero_volume.py
+++ b/tests/edge_cases/test_zero_volume.py
@@ -1,21 +1,22 @@
"""Tests for strategies handling zero-volume candles gracefully."""
import sys
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "strategy-engine"))
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "backtester" / "src"))
-from shared.models import Candle
-from strategies.vwap_strategy import VwapStrategy
-from strategies.volume_profile_strategy import VolumeProfileStrategy
from strategies.rsi_strategy import RsiStrategy
+from strategies.volume_profile_strategy import VolumeProfileStrategy
+from strategies.vwap_strategy import VwapStrategy
+
+from shared.models import Candle
def _candle(close: str, volume: str = "0", idx: int = 0) -> Candle:
- base = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ base = datetime(2025, 1, 1, tzinfo=UTC)
from datetime import timedelta
return Candle(
diff --git a/tests/integration/test_backtest_end_to_end.py b/tests/integration/test_backtest_end_to_end.py
index 4cc0b12..fbc0a24 100644
--- a/tests/integration/test_backtest_end_to_end.py
+++ b/tests/integration/test_backtest_end_to_end.py
@@ -9,19 +9,20 @@ sys.path.insert(
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "strategy-engine"))
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "backtester" / "src"))
+from datetime import UTC, datetime, timedelta
from decimal import Decimal
-from datetime import datetime, timedelta, timezone
-from shared.models import Candle
from backtester.engine import BacktestEngine
+from shared.models import Candle
+
def _generate_candles(prices: list[float], symbol="AAPL") -> list[Candle]:
return [
Candle(
symbol=symbol,
timeframe="1h",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc) + timedelta(hours=i),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC) + timedelta(hours=i),
open=Decimal(str(p)),
high=Decimal(str(p + 100)),
low=Decimal(str(p - 100)),
diff --git a/tests/integration/test_order_execution_flow.py b/tests/integration/test_order_execution_flow.py
index dcbc498..2beb388 100644
--- a/tests/integration/test_order_execution_flow.py
+++ b/tests/integration/test_order_execution_flow.py
@@ -5,14 +5,15 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "order-executor" / "src"))
-import pytest
from decimal import Decimal
from unittest.mock import AsyncMock
-from shared.models import Signal, OrderSide, OrderStatus
+import pytest
from order_executor.executor import OrderExecutor
from order_executor.risk_manager import RiskManager
+from shared.models import OrderSide, OrderStatus, Signal
+
@pytest.mark.asyncio
async def test_signal_to_order_flow():
diff --git a/tests/integration/test_portfolio_tracking_flow.py b/tests/integration/test_portfolio_tracking_flow.py
index b20275a..d91a265 100644
--- a/tests/integration/test_portfolio_tracking_flow.py
+++ b/tests/integration/test_portfolio_tracking_flow.py
@@ -9,9 +9,10 @@ sys.path.insert(
from decimal import Decimal
-from shared.models import Order, OrderSide, OrderType, OrderStatus
from portfolio_manager.portfolio import PortfolioTracker
+from shared.models import Order, OrderSide, OrderStatus, OrderType
+
def test_portfolio_tracks_buy_sell_cycle():
"""Buy then sell should update position and reset on full sell."""
diff --git a/tests/integration/test_strategy_signal_flow.py b/tests/integration/test_strategy_signal_flow.py
index 6b048fb..3f7ec35 100644
--- a/tests/integration/test_strategy_signal_flow.py
+++ b/tests/integration/test_strategy_signal_flow.py
@@ -8,15 +8,16 @@ sys.path.insert(
)
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "strategy-engine"))
-import pytest
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
from unittest.mock import AsyncMock
-from shared.models import Candle
-from shared.events import CandleEvent
+import pytest
from strategy_engine.engine import StrategyEngine
+from shared.events import CandleEvent
+from shared.models import Candle
+
@pytest.fixture
def candles():
@@ -28,7 +29,7 @@ def candles():
Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2025, 1, 1, i, 0, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, i, 0, tzinfo=UTC),
open=price,
high=price + 1,
low=price - 1,