diff options
151 files changed, 4004 insertions, 986 deletions
diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..8daace4 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,18 @@ +__pycache__ +*.pyc +*.pyo +.git +.github +.venv +.env +.env.* +!.env.example +tests/ +docs/ +*.md +.ruff_cache +.pytest_cache +.mypy_cache +monitoring/ +scripts/ +cli/ diff --git a/.env.example b/.env.example index dcaf9a8..2cc65da 100644 --- a/.env.example +++ b/.env.example @@ -1,11 +1,23 @@ -# Alpaca API (get keys from https://app.alpaca.markets) +# === SECRETS (keep secure, do not commit .env) === ALPACA_API_KEY= ALPACA_API_SECRET= -ALPACA_PAPER=true - -REDIS_URL=redis://localhost:6379 +POSTGRES_USER=trading +POSTGRES_PASSWORD=trading DATABASE_URL=postgresql+asyncpg://trading:trading@localhost:5432/trading +REDIS_URL=redis://localhost:6379 +TELEGRAM_BOT_TOKEN= +FINNHUB_API_KEY= +ANTHROPIC_API_KEY= +API_AUTH_TOKEN= +METRICS_AUTH_TOKEN= + +# === CONFIGURATION === +ALPACA_PAPER=true +DRY_RUN=true +POSTGRES_DB=trading LOG_LEVEL=INFO +LOG_FORMAT=json +HEALTH_PORT=8080 RISK_MAX_POSITION_SIZE=0.1 RISK_STOP_LOSS_PCT=5 RISK_DAILY_LOSS_LIMIT_PCT=10 @@ -13,27 +25,19 @@ RISK_TRAILING_STOP_PCT=0 RISK_MAX_OPEN_POSITIONS=10 RISK_VOLATILITY_LOOKBACK=20 RISK_VOLATILITY_SCALE=false -DRY_RUN=true -TELEGRAM_BOT_TOKEN= TELEGRAM_CHAT_ID= TELEGRAM_ENABLED=false -LOG_FORMAT=json -HEALTH_PORT=8080 -CIRCUIT_BREAKER_THRESHOLD=5 -CIRCUIT_BREAKER_TIMEOUT=60 -METRICS_AUTH_TOKEN= # News Collector -FINNHUB_API_KEY= NEWS_POLL_INTERVAL=300 SENTIMENT_AGGREGATE_INTERVAL=900 # Stock Selector -SELECTOR_CANDIDATES_TIME=15:00 -SELECTOR_FILTER_TIME=15:15 SELECTOR_FINAL_TIME=15:30 SELECTOR_MAX_PICKS=3 # LLM (for stock selector) -ANTHROPIC_API_KEY= ANTHROPIC_MODEL=claude-sonnet-4-20250514 + +# === API SECURITY === +CORS_ORIGINS=http://localhost:3000 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..c07541b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,74 @@ +name: CI + +on: + push: + branches: [master] + pull_request: + branches: [master] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: pip install ruff + - run: ruff check . + - run: ruff format --check . + + test: + runs-on: ubuntu-latest + services: + redis: + image: redis:7-alpine + ports: [6379:6379] + options: >- + --health-cmd "redis-cli ping" + --health-interval 5s + --health-timeout 3s + --health-retries 5 + postgres: + image: postgres:16-alpine + env: + POSTGRES_USER: trading + POSTGRES_PASSWORD: trading + POSTGRES_DB: trading + ports: [5432:5432] + options: >- + --health-cmd pg_isready + --health-interval 5s + --health-timeout 3s + --health-retries 5 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: | + pip install -e shared/[dev] + pip install -e services/strategy-engine/[dev] + pip install -e services/data-collector/[dev] + pip install -e services/order-executor/[dev] + pip install -e services/portfolio-manager/[dev] + pip install -e services/news-collector/[dev] + pip install -e services/api/[dev] + pip install -e services/backtester/[dev] + pip install pytest-cov + - run: pytest -v --cov=shared/src --cov=services --cov-report=xml --cov-report=term-missing + env: + DATABASE_URL: postgresql+asyncpg://trading:trading@localhost:5432/trading + REDIS_URL: redis://localhost:6379 + - uses: actions/upload-artifact@v4 + with: + name: coverage-report + path: coverage.xml + + docker: + runs-on: ubuntu-latest + needs: [lint, test] + if: github.ref == 'refs/heads/master' + steps: + - uses: actions/checkout@v4 + - run: docker compose build --quiet diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..6e33f57 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,106 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +US stock trading platform built as a Python microservices architecture. Uses Alpaca Markets API for market data and order execution. Services communicate via Redis Streams and persist to PostgreSQL. + +## Common Commands + +```bash +make infra # Start Redis + Postgres (required before running services/tests) +make up # Start all services via Docker Compose +make down # Stop all services +make test # Run all tests (pytest -v) +make lint # Lint check (ruff check + format check) +make format # Auto-fix lint + format +make migrate # Run DB migrations (alembic upgrade head, from shared/) +make migrate-new msg="description" # Create new migration +make ci # Full CI: install deps, lint, test, Docker build +make e2e # End-to-end tests +``` + +Run a single test file: `pytest services/strategy-engine/tests/test_rsi_strategy.py -v` + +## Architecture + +### Services (each in `services/<name>/`, each has its own Dockerfile) + +- **data-collector** (port 8080): Fetches stock bars from Alpaca, publishes `CandleEvent` to Redis stream `candles` +- **news-collector** (port 8084): Continuously collects news from 7 sources (Finnhub, RSS, SEC EDGAR, Truth Social, Reddit, Fear & Greed, Fed), runs sentiment aggregation every 15 min +- **strategy-engine** (port 8081): Consumes candle events, runs strategies, publishes `SignalEvent` to stream `signals`. Also runs the stock selector at 15:30 ET daily +- **order-executor** (port 8082): Consumes signals, runs risk checks, places orders via Alpaca, publishes `OrderEvent` to stream `orders` +- **portfolio-manager** (port 8083): Tracks positions, PnL, portfolio snapshots +- **api** (port 8000): FastAPI REST endpoint layer +- **backtester**: Offline backtesting engine with walk-forward analysis + +### Event Flow + +``` +Alpaca → data-collector → [candles stream] → strategy-engine → [signals stream] → order-executor → [orders stream] → portfolio-manager + +News sources → news-collector → [news stream] → sentiment aggregator → symbol_scores DB + ↓ + stock selector (15:30 ET) → [selected_stocks stream] → MOC strategy → signals +``` + +All inter-service events use `shared/src/shared/events.py` (CandleEvent, SignalEvent, OrderEvent, NewsEvent) serialized as JSON over Redis Streams via `shared/src/shared/broker.py` (RedisBroker). + +### Shared Library (`shared/`) + +Installed as editable package (`pip install -e shared/`). Contains: +- `models.py` — Pydantic domain models: Candle, Signal, Order, Position, NewsItem, NewsCategory +- `sentiment_models.py` — SymbolScore, MarketSentiment, SelectedStock, Candidate +- `sa_models.py` — SQLAlchemy ORM models (CandleRow, SignalRow, OrderRow, PortfolioSnapshotRow, NewsItemRow, SymbolScoreRow, MarketSentimentRow, StockSelectionRow) +- `broker.py` — RedisBroker (async Redis Streams pub/sub with consumer groups) +- `db.py` — Database class (async SQLAlchemy 2.0), includes news/sentiment/selection CRUD methods +- `alpaca.py` — AlpacaClient (async aiohttp client for Alpaca Trading + Market Data APIs) +- `events.py` — Event types and serialization (CandleEvent, SignalEvent, OrderEvent, NewsEvent) +- `sentiment.py` — SentimentData (legacy gating) + SentimentAggregator (freshness-weighted composite scoring) +- `config.py`, `logging.py`, `metrics.py`, `notifier.py` (Telegram), `resilience.py`, `healthcheck.py` + +DB migrations live in `shared/alembic/`. + +### Strategy System (`services/strategy-engine/strategies/`) + +Strategies extend `BaseStrategy` (in `strategies/base.py`) and implement `on_candle()`, `configure()`, `warmup_period`. The plugin loader (`strategy_engine/plugin_loader.py`) auto-discovers `*.py` files in the strategies directory and loads YAML config from `strategies/config/<strategy_name>.yaml`. + +BaseStrategy provides optional filters (ADX regime, volume, ATR-based stops) via `_init_filters()` and `_apply_filters()`. + +### News-Driven Stock Selector (`services/strategy-engine/src/strategy_engine/stock_selector.py`) + +Dynamic stock selection for MOC (Market on Close) trading. Runs daily at 15:30 ET via `strategy-engine`: + +1. **Candidate Pool**: Top 20 by sentiment score + LLM-recommended stocks from today's news +2. **Technical Filter**: RSI 30-70, price > 20 EMA, volume > 50% average +3. **LLM Final Selection**: Claude picks 2-3 stocks with rationale + +Market gating: blocks all trades when Fear & Greed ≤ 20 or VIX > 30 (`risk_off` regime). + +### News Collector (`services/news-collector/`) + +7 collectors extending `BaseCollector` in `collectors/`: +- `finnhub.py` (5min), `rss.py` (10min), `reddit.py` (15min), `truth_social.py` (15min), `sec_edgar.py` (30min), `fear_greed.py` (1hr), `fed.py` (1hr) +- All use VADER (nltk) for sentiment scoring +- Provider abstraction via `BaseCollector` for future paid API swap (config change only) + +Sentiment aggregation (every 15min) computes per-symbol composite scores with freshness decay and category weights (policy 0.3, news 0.3, social 0.2, filing 0.2). + +### CLI (`cli/`) + +Click-based CLI installed as `trading` command. Depends on the shared library. + +## Tech Stack + +- Python 3.12+, async throughout (asyncio, aiohttp) +- Pydantic for models, SQLAlchemy 2.0 async ORM, Alembic for migrations +- Redis Streams for inter-service messaging +- PostgreSQL 16 for persistence +- Ruff for linting/formatting (line-length=100) +- pytest + pytest-asyncio (asyncio_mode="auto") +- Docker Compose for deployment; monitoring stack (Grafana/Prometheus/Loki) available via `--profile monitoring` + +## Environment + +Copy `.env.example` to `.env`. Key vars: `ALPACA_API_KEY`, `ALPACA_API_SECRET`, `ALPACA_PAPER=true`, `DRY_RUN=true`, `DATABASE_URL`, `REDIS_URL`, `FINNHUB_API_KEY`, `ANTHROPIC_API_KEY`. DRY_RUN=true simulates order fills without hitting Alpaca. Stock selector requires `ANTHROPIC_API_KEY` to be set. diff --git a/cli/src/trading_cli/commands/backtest.py b/cli/src/trading_cli/commands/backtest.py index 3876f1b..c17c61f 100644 --- a/cli/src/trading_cli/commands/backtest.py +++ b/cli/src/trading_cli/commands/backtest.py @@ -35,11 +35,12 @@ def backtest(): def run(strategy, symbol, timeframe, balance, output_format, file_path): """Run a backtest for a strategy.""" try: - from strategy_engine.plugin_loader import load_strategies from backtester.engine import BacktestEngine - from backtester.reporter import format_report, export_csv, export_json - from shared.db import Database + from backtester.reporter import export_csv, export_json, format_report + from strategy_engine.plugin_loader import load_strategies + from shared.config import Settings + from shared.db import Database from shared.models import Candle except ImportError as e: click.echo(f"Error: Could not import required modules: {e}", err=True) @@ -58,7 +59,7 @@ def run(strategy, symbol, timeframe, balance, output_format, file_path): async def _run(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: candle_rows = await db.get_candles(symbol, timeframe, limit=500) @@ -66,20 +67,19 @@ def run(strategy, symbol, timeframe, balance, output_format, file_path): click.echo(f"Error: No candles found for {symbol} {timeframe}", err=True) sys.exit(1) - candles = [] - for row in reversed(candle_rows): # get_candles returns DESC, we need ASC - candles.append( - Candle( - symbol=row["symbol"], - timeframe=row["timeframe"], - open_time=row["open_time"], - open=row["open"], - high=row["high"], - low=row["low"], - close=row["close"], - volume=row["volume"], - ) + candles = [ + Candle( + symbol=row["symbol"], + timeframe=row["timeframe"], + open_time=row["open_time"], + open=row["open"], + high=row["high"], + low=row["low"], + close=row["close"], + volume=row["volume"], ) + for row in reversed(candle_rows) # get_candles returns DESC, we need ASC + ] engine = BacktestEngine(strat, Decimal(str(balance))) result = engine.run(candles) @@ -111,10 +111,11 @@ def run(strategy, symbol, timeframe, balance, output_format, file_path): def walk_forward(strategy, symbol, timeframe, balance, windows): """Run walk-forward analysis to detect overfitting.""" try: - from strategy_engine.plugin_loader import load_strategies from backtester.walk_forward import WalkForwardEngine - from shared.db import Database + from strategy_engine.plugin_loader import load_strategies + from shared.config import Settings + from shared.db import Database from shared.models import Candle except ImportError as e: click.echo(f"Error: Could not import required modules: {e}", err=True) @@ -131,7 +132,7 @@ def walk_forward(strategy, symbol, timeframe, balance, windows): async def _run(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: rows = await db.get_candles(symbol, timeframe, limit=2000) diff --git a/cli/src/trading_cli/commands/data.py b/cli/src/trading_cli/commands/data.py index 1ecc15f..64639cf 100644 --- a/cli/src/trading_cli/commands/data.py +++ b/cli/src/trading_cli/commands/data.py @@ -1,5 +1,6 @@ import asyncio import sys +from datetime import UTC from pathlib import Path import click @@ -39,23 +40,23 @@ def history(symbol, timeframe, since, limit): """Download historical stock market data for a symbol.""" try: from shared.alpaca_client import AlpacaClient - from shared.db import Database from shared.config import Settings + from shared.db import Database except ImportError as e: click.echo(f"Error: Could not import required modules: {e}", err=True) sys.exit(1) async def _fetch(): - from datetime import datetime, timezone + from datetime import datetime settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() start = None if since: try: - start = datetime.fromisoformat(since).replace(tzinfo=timezone.utc) + start = datetime.fromisoformat(since).replace(tzinfo=UTC) except ValueError: click.echo( f"Error: Invalid date format '{since}'. Use ISO format (e.g. 2024-01-01).", @@ -64,8 +65,8 @@ def history(symbol, timeframe, since, limit): sys.exit(1) client = AlpacaClient( - api_key=settings.alpaca_api_key, - api_secret=settings.alpaca_api_secret, + api_key=settings.alpaca_api_key.get_secret_value(), + api_secret=settings.alpaca_api_secret.get_secret_value(), base_url=getattr(settings, "alpaca_base_url", "https://paper-api.alpaca.markets"), ) @@ -97,17 +98,18 @@ def history(symbol, timeframe, since, limit): def list_(): """List available data streams and symbols.""" try: - from shared.db import Database + from sqlalchemy import func, select + from shared.config import Settings + from shared.db import Database from shared.sa_models import CandleRow - from sqlalchemy import select, func except ImportError as e: click.echo(f"Error: Could not import required modules: {e}", err=True) sys.exit(1) async def _list(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: stmt = ( diff --git a/cli/src/trading_cli/commands/portfolio.py b/cli/src/trading_cli/commands/portfolio.py index ad9a6b4..fd3ebd6 100644 --- a/cli/src/trading_cli/commands/portfolio.py +++ b/cli/src/trading_cli/commands/portfolio.py @@ -1,6 +1,6 @@ import asyncio import sys -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta import click from rich.console import Console @@ -17,17 +17,18 @@ def portfolio(): def show(): """Show the current portfolio holdings and balances.""" try: - from shared.db import Database + from sqlalchemy import select + from shared.config import Settings + from shared.db import Database from shared.sa_models import PositionRow - from sqlalchemy import select except ImportError as e: click.echo(f"Error: Could not import required modules: {e}", err=True) sys.exit(1) async def _show(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: async with db.get_session() as session: @@ -71,20 +72,21 @@ def show(): def history(days): """Show PnL history for the portfolio.""" try: - from shared.db import Database + from sqlalchemy import select + from shared.config import Settings + from shared.db import Database from shared.sa_models import PortfolioSnapshotRow - from sqlalchemy import select except ImportError as e: click.echo(f"Error: Could not import required modules: {e}", err=True) sys.exit(1) async def _history(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: - since = datetime.now(timezone.utc) - timedelta(days=days) + since = datetime.now(UTC) - timedelta(days=days) stmt = ( select(PortfolioSnapshotRow) .where(PortfolioSnapshotRow.snapshot_at >= since) diff --git a/cli/src/trading_cli/commands/service.py b/cli/src/trading_cli/commands/service.py index d01eaae..6d02f14 100644 --- a/cli/src/trading_cli/commands/service.py +++ b/cli/src/trading_cli/commands/service.py @@ -1,4 +1,5 @@ import subprocess + import click diff --git a/cli/src/trading_cli/main.py b/cli/src/trading_cli/main.py index 1129bdd..0ed2307 100644 --- a/cli/src/trading_cli/main.py +++ b/cli/src/trading_cli/main.py @@ -1,10 +1,11 @@ import click -from trading_cli.commands.data import data -from trading_cli.commands.trade import trade + from trading_cli.commands.backtest import backtest +from trading_cli.commands.data import data from trading_cli.commands.portfolio import portfolio -from trading_cli.commands.strategy import strategy from trading_cli.commands.service import service +from trading_cli.commands.strategy import strategy +from trading_cli.commands.trade import trade @click.group() diff --git a/cli/tests/test_cli_strategy.py b/cli/tests/test_cli_strategy.py index cf3057b..75ba4df 100644 --- a/cli/tests/test_cli_strategy.py +++ b/cli/tests/test_cli_strategy.py @@ -1,6 +1,7 @@ """Tests for strategy CLI commands.""" -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch + from click.testing import CliRunner from trading_cli.main import cli diff --git a/docker-compose.yml b/docker-compose.yml index 63630ff..60462ec 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,15 +10,21 @@ services: interval: 5s timeout: 3s retries: 5 + networks: [internal] + deploy: + resources: + limits: + memory: 256M + cpus: '0.5' postgres: image: postgres:16-alpine ports: - "5432:5432" environment: - POSTGRES_USER: trading - POSTGRES_PASSWORD: trading - POSTGRES_DB: trading + POSTGRES_USER: ${POSTGRES_USER:-trading} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-trading} + POSTGRES_DB: ${POSTGRES_DB:-trading} volumes: - postgres_data:/var/lib/postgresql/data healthcheck: @@ -26,6 +32,12 @@ services: interval: 5s timeout: 3s retries: 5 + networks: [internal] + deploy: + resources: + limits: + memory: 256M + cpus: '0.5' data-collector: build: @@ -45,6 +57,12 @@ services: timeout: 5s retries: 3 restart: unless-stopped + networks: [internal] + deploy: + resources: + limits: + memory: 512M + cpus: '1.0' strategy-engine: build: @@ -64,6 +82,12 @@ services: timeout: 5s retries: 3 restart: unless-stopped + networks: [internal] + deploy: + resources: + limits: + memory: 1G + cpus: '1.0' order-executor: build: @@ -83,6 +107,12 @@ services: timeout: 5s retries: 3 restart: unless-stopped + networks: [internal] + deploy: + resources: + limits: + memory: 512M + cpus: '1.0' portfolio-manager: build: @@ -102,6 +132,12 @@ services: timeout: 5s retries: 3 restart: unless-stopped + networks: [internal] + deploy: + resources: + limits: + memory: 512M + cpus: '1.0' api: build: @@ -121,6 +157,12 @@ services: timeout: 5s retries: 3 restart: unless-stopped + networks: [internal] + deploy: + resources: + limits: + memory: 512M + cpus: '1.0' news-collector: build: @@ -140,6 +182,12 @@ services: timeout: 5s retries: 3 restart: unless-stopped + networks: [internal] + deploy: + resources: + limits: + memory: 512M + cpus: '1.0' loki: image: grafana/loki:latest @@ -150,6 +198,12 @@ services: - ./monitoring/loki/loki-config.yaml:/etc/loki/local-config.yaml - loki_data:/loki command: -config.file=/etc/loki/local-config.yaml + networks: [monitoring] + deploy: + resources: + limits: + memory: 512M + cpus: '1.0' promtail: image: grafana/promtail:latest @@ -160,6 +214,12 @@ services: command: -config.file=/etc/promtail/config.yaml depends_on: - loki + networks: [monitoring] + deploy: + resources: + limits: + memory: 512M + cpus: '1.0' prometheus: image: prom/prometheus:latest @@ -168,11 +228,18 @@ services: - "9090:9090" volumes: - ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml + - ./monitoring/prometheus/alert_rules.yml:/etc/prometheus/alert_rules.yml depends_on: - data-collector - strategy-engine - order-executor - portfolio-manager + networks: [internal, monitoring] + deploy: + resources: + limits: + memory: 512M + cpus: '1.0' grafana: image: grafana/grafana:latest @@ -187,8 +254,20 @@ services: - ./monitoring/grafana/dashboards:/var/lib/grafana/dashboards depends_on: - prometheus + networks: [internal, monitoring] + deploy: + resources: + limits: + memory: 512M + cpus: '1.0' volumes: redis_data: postgres_data: loki_data: + +networks: + internal: + driver: bridge + monitoring: + driver: bridge diff --git a/docs/superpowers/plans/2026-04-02-platform-upgrade.md b/docs/superpowers/plans/2026-04-02-platform-upgrade.md new file mode 100644 index 0000000..c28d287 --- /dev/null +++ b/docs/superpowers/plans/2026-04-02-platform-upgrade.md @@ -0,0 +1,1991 @@ +# Platform Upgrade Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Upgrade the trading platform across 5 phases: shared library hardening, infrastructure improvements, service-level fixes, API security, and operational maturity. + +**Architecture:** Bottom-up approach — harden the shared library first (resilience, DB pooling, Redis resilience, config validation), then improve infrastructure (Docker, DB indexes), then fix all services (graceful shutdown, exception handling), then add API security (auth, CORS, rate limiting), and finally improve operations (CI/CD, linting, alerting). + +**Tech Stack:** Python 3.12, asyncio, tenacity, SQLAlchemy 2.0 async, Redis Streams, FastAPI, slowapi, Ruff, GitHub Actions, Prometheus Alertmanager + +--- + +## File Structure + +### New Files +- `shared/src/shared/resilience.py` — Retry decorator, circuit breaker, timeout wrapper +- `shared/tests/test_resilience.py` — Tests for resilience module +- `shared/alembic/versions/003_add_missing_indexes.py` — DB index migration +- `.dockerignore` — Docker build exclusions +- `services/api/src/trading_api/dependencies/auth.py` — Bearer token auth dependency +- `.github/workflows/ci.yml` — GitHub Actions CI pipeline +- `monitoring/prometheus/alert_rules.yml` — Prometheus alerting rules + +### Modified Files +- `shared/src/shared/db.py` — Add connection pool config +- `shared/src/shared/broker.py` — Add Redis resilience +- `shared/src/shared/config.py` — Add validators, SecretStr, new fields +- `shared/pyproject.toml` — Pin deps, add tenacity +- `pyproject.toml` — Enhanced ruff rules, pytest-cov +- `services/strategy-engine/src/strategy_engine/stock_selector.py` — Fix bug, deduplicate, session reuse +- `services/*/src/*/main.py` — Signal handlers, exception specialization (all 6 services) +- `services/*/Dockerfile` — Multi-stage builds, non-root user (all 7 Dockerfiles) +- `services/api/pyproject.toml` — Add slowapi +- `services/api/src/trading_api/main.py` — CORS, auth, rate limiting +- `services/api/src/trading_api/routers/*.py` — Input validation, response models +- `docker-compose.yml` — Remove hardcoded creds, add resource limits, networks +- `.env.example` — Add new fields, mark secrets +- `monitoring/prometheus.yml` — Reference alert rules + +--- + +## Phase 1: Shared Library Hardening + +### Task 1: Implement Resilience Module + +**Files:** +- Create: `shared/src/shared/resilience.py` +- Create: `shared/tests/test_resilience.py` +- Modify: `shared/pyproject.toml:6-18` + +- [ ] **Step 1: Add tenacity dependency to shared/pyproject.toml** + +In `shared/pyproject.toml`, add `tenacity` to the dependencies list: + +```python +dependencies = [ + "pydantic>=2.8,<3", + "pydantic-settings>=2.0,<3", + "redis>=5.0,<6", + "asyncpg>=0.29,<1", + "sqlalchemy[asyncio]>=2.0,<3", + "alembic>=1.13,<2", + "structlog>=24.0,<25", + "prometheus-client>=0.20,<1", + "pyyaml>=6.0,<7", + "aiohttp>=3.9,<4", + "rich>=13.0,<14", + "tenacity>=8.2,<10", +] +``` + +Note: This also pins all existing dependencies with upper bounds. + +- [ ] **Step 2: Write failing tests for retry_async** + +Create `shared/tests/test_resilience.py`: + +```python +"""Tests for the resilience module.""" + +import asyncio + +import pytest + +from shared.resilience import retry_async, CircuitBreaker, async_timeout + + +class TestRetryAsync: + async def test_succeeds_without_retry(self): + call_count = 0 + + @retry_async(max_retries=3) + async def succeed(): + nonlocal call_count + call_count += 1 + return "ok" + + result = await succeed() + assert result == "ok" + assert call_count == 1 + + async def test_retries_on_failure_then_succeeds(self): + call_count = 0 + + @retry_async(max_retries=3, base_delay=0.01) + async def fail_twice(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise ConnectionError("fail") + return "ok" + + result = await fail_twice() + assert result == "ok" + assert call_count == 3 + + async def test_raises_after_max_retries(self): + @retry_async(max_retries=2, base_delay=0.01) + async def always_fail(): + raise ConnectionError("fail") + + with pytest.raises(ConnectionError): + await always_fail() + + async def test_no_retry_on_excluded_exception(self): + call_count = 0 + + @retry_async(max_retries=3, base_delay=0.01, exclude=(ValueError,)) + async def raise_value_error(): + nonlocal call_count + call_count += 1 + raise ValueError("bad input") + + with pytest.raises(ValueError): + await raise_value_error() + assert call_count == 1 + + +class TestCircuitBreaker: + async def test_closed_allows_calls(self): + cb = CircuitBreaker(failure_threshold=3, cooldown=0.1) + + async def succeed(): + return "ok" + + result = await cb.call(succeed) + assert result == "ok" + + async def test_opens_after_threshold(self): + cb = CircuitBreaker(failure_threshold=2, cooldown=60) + + async def fail(): + raise ConnectionError("fail") + + for _ in range(2): + with pytest.raises(ConnectionError): + await cb.call(fail) + + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + async def test_half_open_after_cooldown(self): + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + call_count = 0 + + async def fail_then_succeed(): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise ConnectionError("fail") + return "recovered" + + # Trip the breaker + for _ in range(2): + with pytest.raises(ConnectionError): + await cb.call(fail_then_succeed) + + # Wait for cooldown + await asyncio.sleep(0.1) + + # Should allow one call (half-open) + result = await cb.call(fail_then_succeed) + assert result == "recovered" + + +class TestAsyncTimeout: + async def test_completes_within_timeout(self): + async with async_timeout(1.0): + await asyncio.sleep(0.01) + + async def test_raises_on_timeout(self): + with pytest.raises(asyncio.TimeoutError): + async with async_timeout(0.01): + await asyncio.sleep(1.0) +``` + +- [ ] **Step 3: Run tests to verify they fail** + +Run: `pytest shared/tests/test_resilience.py -v` +Expected: FAIL with `ImportError: cannot import name 'retry_async' from 'shared.resilience'` + +- [ ] **Step 4: Implement resilience module** + +Write `shared/src/shared/resilience.py`: + +```python +"""Resilience utilities: retry, circuit breaker, timeout.""" + +import asyncio +import functools +import logging +import time +from contextlib import asynccontextmanager + +logger = logging.getLogger(__name__) + + +def retry_async( + max_retries: int = 3, + base_delay: float = 1.0, + max_delay: float = 30.0, + exclude: tuple[type[Exception], ...] = (), +): + """Decorator for async functions with exponential backoff + jitter. + + Args: + max_retries: Maximum number of retry attempts. + base_delay: Initial delay in seconds between retries. + max_delay: Maximum delay cap in seconds. + exclude: Exception types that should NOT be retried. + """ + + def decorator(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + last_exc = None + for attempt in range(max_retries + 1): + try: + return await func(*args, **kwargs) + except exclude: + raise + except Exception as exc: + last_exc = exc + if attempt == max_retries: + raise + delay = min(base_delay * (2**attempt), max_delay) + # Add jitter: 50-100% of delay + import random + + delay = delay * (0.5 + random.random() * 0.5) + logger.warning( + "retry attempt=%d/%d delay=%.2fs error=%s func=%s", + attempt + 1, + max_retries, + delay, + str(exc), + func.__name__, + ) + await asyncio.sleep(delay) + raise last_exc # Should not reach here, but just in case + + return wrapper + + return decorator + + +class CircuitBreaker: + """Circuit breaker: opens after consecutive failures, auto-recovers after cooldown.""" + + def __init__(self, failure_threshold: int = 5, cooldown: float = 60.0) -> None: + self._failure_threshold = failure_threshold + self._cooldown = cooldown + self._failure_count = 0 + self._last_failure_time: float = 0 + self._state = "closed" # closed, open, half_open + + async def call(self, func, *args, **kwargs): + if self._state == "open": + if time.monotonic() - self._last_failure_time >= self._cooldown: + self._state = "half_open" + else: + raise RuntimeError("Circuit breaker is open") + + try: + result = await func(*args, **kwargs) + self._failure_count = 0 + self._state = "closed" + return result + except Exception: + self._failure_count += 1 + self._last_failure_time = time.monotonic() + if self._failure_count >= self._failure_threshold: + self._state = "open" + logger.error( + "circuit_breaker_opened failures=%d cooldown=%.0fs", + self._failure_count, + self._cooldown, + ) + raise + + +@asynccontextmanager +async def async_timeout(seconds: float): + """Async context manager that raises TimeoutError after given seconds.""" + try: + async with asyncio.timeout(seconds): + yield + except TimeoutError: + raise asyncio.TimeoutError(f"Operation timed out after {seconds}s") +``` + +- [ ] **Step 5: Run tests to verify they pass** + +Run: `pytest shared/tests/test_resilience.py -v` +Expected: All 8 tests PASS + +- [ ] **Step 6: Commit** + +```bash +git add shared/src/shared/resilience.py shared/tests/test_resilience.py shared/pyproject.toml +git commit -m "feat: implement resilience module with retry, circuit breaker, timeout" +``` + +--- + +### Task 2: Add DB Connection Pooling + +**Files:** +- Modify: `shared/src/shared/db.py:39-44` +- Modify: `shared/src/shared/config.py:10-11` +- Modify: `shared/tests/test_db.py` (add pool config test) + +- [ ] **Step 1: Write failing test for pool config** + +Add to `shared/tests/test_db.py`: + +```python +async def test_connect_configures_pool(tmp_path): + """Engine should be created with pool configuration.""" + db = Database("sqlite+aiosqlite:///:memory:") + await db.connect() + engine = db._engine + pool = engine.pool + # aiosqlite uses StaticPool so we just verify connect works + assert engine is not None + await db.close() +``` + +- [ ] **Step 2: Add pool settings to config.py** + +In `shared/src/shared/config.py`, add after line 11 (`database_url`): + +```python + db_pool_size: int = 20 + db_max_overflow: int = 10 + db_pool_recycle: int = 3600 +``` + +- [ ] **Step 3: Update Database.connect() with pool parameters** + +In `shared/src/shared/db.py`, replace line 41: + +```python + self._engine = create_async_engine(self._database_url) +``` + +with: + +```python + self._engine = create_async_engine( + self._database_url, + pool_pre_ping=True, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + ) +``` + +Update the `connect` method signature to accept pool params: + +```python + async def connect( + self, + pool_size: int = 20, + max_overflow: int = 10, + pool_recycle: int = 3600, + ) -> None: + """Create the async engine, session factory, and all tables.""" + if self._database_url.startswith("sqlite"): + self._engine = create_async_engine(self._database_url) + else: + self._engine = create_async_engine( + self._database_url, + pool_pre_ping=True, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + ) + self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False) + async with self._engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) +``` + +- [ ] **Step 4: Run tests** + +Run: `pytest shared/tests/test_db.py -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add shared/src/shared/db.py shared/src/shared/config.py shared/tests/test_db.py +git commit -m "feat: add DB connection pooling with configurable pool_size, overflow, recycle" +``` + +--- + +### Task 3: Add Redis Resilience + +**Files:** +- Modify: `shared/src/shared/broker.py:1-13,15-18,102-104` +- Create: `shared/tests/test_broker_resilience.py` + +- [ ] **Step 1: Write failing tests for Redis resilience** + +Create `shared/tests/test_broker_resilience.py`: + +```python +"""Tests for Redis broker resilience features.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from shared.broker import RedisBroker + + +class TestBrokerResilience: + async def test_publish_retries_on_connection_error(self): + broker = RedisBroker.__new__(RedisBroker) + mock_redis = AsyncMock() + call_count = 0 + + async def xadd_failing(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise ConnectionError("Redis connection lost") + return "msg-id" + + mock_redis.xadd = xadd_failing + broker._redis = mock_redis + + await broker.publish("test-stream", {"key": "value"}) + assert call_count == 3 + + async def test_ping_retries_on_timeout(self): + broker = RedisBroker.__new__(RedisBroker) + mock_redis = AsyncMock() + call_count = 0 + + async def ping_failing(): + nonlocal call_count + call_count += 1 + if call_count < 2: + raise TimeoutError("timeout") + return True + + mock_redis.ping = ping_failing + broker._redis = mock_redis + + result = await broker.ping() + assert result is True + assert call_count == 2 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `pytest shared/tests/test_broker_resilience.py -v` +Expected: FAIL (publish doesn't retry) + +- [ ] **Step 3: Add resilience to broker.py** + +Replace `shared/src/shared/broker.py`: + +```python +"""Redis Streams broker for the trading platform.""" + +import json +import logging +from typing import Any + +import redis.asyncio + +from shared.resilience import retry_async + +logger = logging.getLogger(__name__) + + +class RedisBroker: + """Async Redis Streams broker for publishing and reading events.""" + + def __init__(self, redis_url: str) -> None: + self._redis = redis.asyncio.from_url( + redis_url, + socket_keepalive=True, + health_check_interval=30, + retry_on_timeout=True, + ) + + @retry_async(max_retries=3, base_delay=0.5, exclude=(ValueError,)) + async def publish(self, stream: str, data: dict[str, Any]) -> None: + """Publish a message to a Redis stream.""" + payload = json.dumps(data) + await self._redis.xadd(stream, {"payload": payload}) + + async def ensure_group(self, stream: str, group: str) -> None: + """Create a consumer group if it doesn't exist.""" + try: + await self._redis.xgroup_create(stream, group, id="0", mkstream=True) + except redis.ResponseError as e: + if "BUSYGROUP" not in str(e): + raise + + @retry_async(max_retries=3, base_delay=0.5, exclude=(ValueError,)) + async def read_group( + self, + stream: str, + group: str, + consumer: str, + count: int = 10, + block: int = 0, + ) -> list[tuple[str, dict[str, Any]]]: + """Read messages from a consumer group. Returns list of (message_id, data).""" + results = await self._redis.xreadgroup( + group, consumer, {stream: ">"}, count=count, block=block + ) + messages = [] + if results: + for _stream, entries in results: + for msg_id, fields in entries: + payload = fields.get(b"payload") or fields.get("payload") + if payload: + if isinstance(payload, bytes): + payload = payload.decode() + if isinstance(msg_id, bytes): + msg_id = msg_id.decode() + messages.append((msg_id, json.loads(payload))) + return messages + + async def ack(self, stream: str, group: str, *msg_ids: str) -> None: + """Acknowledge messages in a consumer group.""" + if msg_ids: + await self._redis.xack(stream, group, *msg_ids) + + async def read_pending( + self, + stream: str, + group: str, + consumer: str, + count: int = 10, + ) -> list[tuple[str, dict[str, Any]]]: + """Read pending (unacknowledged) messages for this consumer.""" + results = await self._redis.xreadgroup(group, consumer, {stream: "0"}, count=count) + messages = [] + if results: + for _stream, entries in results: + for msg_id, fields in entries: + if not fields: + continue + payload = fields.get(b"payload") or fields.get("payload") + if payload: + if isinstance(payload, bytes): + payload = payload.decode() + if isinstance(msg_id, bytes): + msg_id = msg_id.decode() + messages.append((msg_id, json.loads(payload))) + return messages + + async def read( + self, + stream: str, + last_id: str = "$", + count: int = 10, + block: int = 0, + ) -> list[dict[str, Any]]: + """Read messages (original method, kept for backward compatibility).""" + results = await self._redis.xread({stream: last_id}, count=count, block=block) + messages = [] + if results: + for _stream, entries in results: + for _msg_id, fields in entries: + payload = fields.get(b"payload") or fields.get("payload") + if payload: + if isinstance(payload, bytes): + payload = payload.decode() + messages.append(json.loads(payload)) + return messages + + @retry_async(max_retries=2, base_delay=0.5) + async def ping(self) -> bool: + """Ping the Redis server; return True if reachable.""" + return await self._redis.ping() + + async def close(self) -> None: + """Close the Redis connection.""" + await self._redis.aclose() +``` + +- [ ] **Step 4: Run tests** + +Run: `pytest shared/tests/test_broker_resilience.py -v` +Expected: PASS + +Run: `pytest shared/tests/test_broker.py -v` +Expected: PASS (existing tests still work) + +- [ ] **Step 5: Commit** + +```bash +git add shared/src/shared/broker.py shared/tests/test_broker_resilience.py +git commit -m "feat: add retry and resilience to Redis broker with keepalive" +``` + +--- + +### Task 4: Config Validation & SecretStr + +**Files:** +- Modify: `shared/src/shared/config.py` +- Create: `shared/tests/test_config_validation.py` + +- [ ] **Step 1: Write failing tests for config validation** + +Create `shared/tests/test_config_validation.py`: + +```python +"""Tests for config validation.""" + +import pytest +from pydantic import ValidationError + +from shared.config import Settings + + +class TestConfigValidation: + def test_valid_defaults(self): + settings = Settings() + assert settings.risk_max_position_size == 0.1 + + def test_invalid_position_size(self): + with pytest.raises(ValidationError, match="risk_max_position_size"): + Settings(risk_max_position_size=-0.1) + + def test_invalid_health_port(self): + with pytest.raises(ValidationError, match="health_port"): + Settings(health_port=80) + + def test_invalid_log_level(self): + with pytest.raises(ValidationError, match="log_level"): + Settings(log_level="INVALID") + + def test_secret_fields_masked(self): + settings = Settings(alpaca_api_key="my-secret-key") + assert "my-secret-key" not in repr(settings) + assert settings.alpaca_api_key.get_secret_value() == "my-secret-key" +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `pytest shared/tests/test_config_validation.py -v` +Expected: FAIL + +- [ ] **Step 3: Update config.py with validators and SecretStr** + +Replace `shared/src/shared/config.py`: + +```python +"""Shared configuration settings for the trading platform.""" + +from pydantic import SecretStr, field_validator +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + # Alpaca + alpaca_api_key: SecretStr = SecretStr("") + alpaca_api_secret: SecretStr = SecretStr("") + alpaca_paper: bool = True + # Infrastructure + redis_url: SecretStr = SecretStr("redis://localhost:6379") + database_url: SecretStr = SecretStr("postgresql://trading:trading@localhost:5432/trading") + # DB pool + db_pool_size: int = 20 + db_max_overflow: int = 10 + db_pool_recycle: int = 3600 + # Logging + log_level: str = "INFO" + log_format: str = "json" + # Health + health_port: int = 8080 + metrics_auth_token: str = "" + # Risk + risk_max_position_size: float = 0.1 + risk_stop_loss_pct: float = 5.0 + risk_daily_loss_limit_pct: float = 10.0 + risk_trailing_stop_pct: float = 0.0 + risk_max_open_positions: int = 10 + risk_volatility_lookback: int = 20 + risk_volatility_scale: bool = False + risk_max_portfolio_exposure: float = 0.8 + risk_max_correlated_exposure: float = 0.5 + risk_correlation_threshold: float = 0.7 + risk_var_confidence: float = 0.95 + risk_var_limit_pct: float = 5.0 + risk_drawdown_reduction_threshold: float = 0.1 + risk_drawdown_halt_threshold: float = 0.2 + risk_max_consecutive_losses: int = 5 + risk_loss_pause_minutes: int = 60 + dry_run: bool = True + # Telegram + telegram_bot_token: SecretStr = SecretStr("") + telegram_chat_id: str = "" + telegram_enabled: bool = False + # News + finnhub_api_key: SecretStr = SecretStr("") + news_poll_interval: int = 300 + sentiment_aggregate_interval: int = 900 + # Stock selector + selector_final_time: str = "15:30" + selector_max_picks: int = 3 + # LLM + anthropic_api_key: SecretStr = SecretStr("") + anthropic_model: str = "claude-sonnet-4-20250514" + # API security + api_auth_token: SecretStr = SecretStr("") + cors_origins: str = "http://localhost:3000" + + model_config = {"env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"} + + @field_validator("risk_max_position_size") + @classmethod + def validate_position_size(cls, v: float) -> float: + if v <= 0 or v > 1: + raise ValueError("risk_max_position_size must be between 0 and 1 (exclusive)") + return v + + @field_validator("health_port") + @classmethod + def validate_health_port(cls, v: int) -> int: + if v < 1024 or v > 65535: + raise ValueError("health_port must be between 1024 and 65535") + return v + + @field_validator("log_level") + @classmethod + def validate_log_level(cls, v: str) -> str: + valid = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} + if v.upper() not in valid: + raise ValueError(f"log_level must be one of {valid}") + return v.upper() +``` + +- [ ] **Step 4: Update all consumers to use .get_secret_value()** + +Every place that reads `settings.alpaca_api_key` etc. must now call `.get_secret_value()`. Key files to update: + +**`shared/src/shared/alpaca.py`** — where AlpacaClient is instantiated (in each service main.py), change: +```python +# Before: +alpaca = AlpacaClient(cfg.alpaca_api_key, cfg.alpaca_api_secret, paper=cfg.alpaca_paper) +# After: +alpaca = AlpacaClient( + cfg.alpaca_api_key.get_secret_value(), + cfg.alpaca_api_secret.get_secret_value(), + paper=cfg.alpaca_paper, +) +``` + +**Each service main.py** — where `Database(cfg.database_url)` and `RedisBroker(cfg.redis_url)` are called: +```python +# Before: +db = Database(cfg.database_url) +broker = RedisBroker(cfg.redis_url) +# After: +db = Database(cfg.database_url.get_secret_value()) +broker = RedisBroker(cfg.redis_url.get_secret_value()) +``` + +**`shared/src/shared/notifier.py`** — where telegram_bot_token is used: +```python +# Change token access to .get_secret_value() +``` + +**`services/strategy-engine/src/strategy_engine/main.py`** — where anthropic_api_key is passed: +```python +# Before: +anthropic_api_key=cfg.anthropic_api_key, +# After: +anthropic_api_key=cfg.anthropic_api_key.get_secret_value(), +``` + +**`services/news-collector/src/news_collector/main.py`** — where finnhub_api_key is used: +```python +# Before: +cfg.finnhub_api_key +# After: +cfg.finnhub_api_key.get_secret_value() +``` + +- [ ] **Step 5: Run all tests** + +Run: `pytest shared/tests/test_config_validation.py -v` +Expected: PASS + +Run: `pytest -v` +Expected: All tests PASS (no regressions from SecretStr changes) + +- [ ] **Step 6: Commit** + +```bash +git add shared/src/shared/config.py shared/tests/test_config_validation.py +git add services/*/src/*/main.py shared/src/shared/notifier.py +git commit -m "feat: add config validation, SecretStr for secrets, API security fields" +``` + +--- + +### Task 5: Pin All Dependencies + +**Files:** +- Modify: `shared/pyproject.toml` (already done in Task 1) +- Modify: `services/strategy-engine/pyproject.toml` +- Modify: `services/backtester/pyproject.toml` +- Modify: `services/api/pyproject.toml` +- Modify: `services/news-collector/pyproject.toml` +- Modify: `services/data-collector/pyproject.toml` +- Modify: `services/order-executor/pyproject.toml` +- Modify: `services/portfolio-manager/pyproject.toml` + +- [ ] **Step 1: Pin service dependencies** + +`services/strategy-engine/pyproject.toml`: +```toml +dependencies = [ + "pandas>=2.1,<3", + "numpy>=1.26,<3", + "trading-shared", +] +``` + +`services/backtester/pyproject.toml`: +```toml +dependencies = ["pandas>=2.1,<3", "numpy>=1.26,<3", "rich>=13.0,<14", "trading-shared"] +``` + +`services/api/pyproject.toml`: +```toml +dependencies = [ + "fastapi>=0.110,<1", + "uvicorn>=0.27,<1", + "slowapi>=0.1.9,<1", + "trading-shared", +] +``` + +`services/news-collector/pyproject.toml`: +```toml +dependencies = [ + "trading-shared", + "feedparser>=6.0,<7", + "nltk>=3.8,<4", + "aiohttp>=3.9,<4", +] +``` + +`shared/pyproject.toml` optional deps: +```toml +[project.optional-dependencies] +dev = [ + "pytest>=8.0,<9", + "pytest-asyncio>=0.23,<1", + "ruff>=0.4,<1", +] +claude = [ + "anthropic>=0.40,<1", +] +``` + +- [ ] **Step 2: Verify installation works** + +Run: `pip install -e shared/ && pip install -e services/strategy-engine/ && pip install -e services/api/` +Expected: No errors + +- [ ] **Step 3: Commit** + +```bash +git add shared/pyproject.toml services/*/pyproject.toml +git commit -m "chore: pin all dependencies with upper bounds" +``` + +--- + +## Phase 2: Infrastructure Hardening + +### Task 6: Docker Secrets & Environment Cleanup + +**Files:** +- Modify: `docker-compose.yml:17-21` +- Modify: `.env.example` + +- [ ] **Step 1: Replace hardcoded Postgres credentials in docker-compose.yml** + +In `docker-compose.yml`, replace the postgres service environment: + +```yaml + postgres: + image: postgres:16-alpine + environment: + POSTGRES_USER: ${POSTGRES_USER:-trading} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-trading} + POSTGRES_DB: ${POSTGRES_DB:-trading} +``` + +- [ ] **Step 2: Update .env.example with secret annotations** + +Add to `.env.example`: + +```bash +# === SECRETS (keep secure, do not commit .env) === +ALPACA_API_KEY= +ALPACA_API_SECRET= +DATABASE_URL=postgresql+asyncpg://trading:trading@localhost:5432/trading +REDIS_URL=redis://localhost:6379 +TELEGRAM_BOT_TOKEN= +FINNHUB_API_KEY= +ANTHROPIC_API_KEY= +API_AUTH_TOKEN= +POSTGRES_USER=trading +POSTGRES_PASSWORD=trading +POSTGRES_DB=trading + +# === CONFIGURATION === +ALPACA_PAPER=true +DRY_RUN=true +LOG_LEVEL=INFO +LOG_FORMAT=json +HEALTH_PORT=8080 +# ... (keep existing config vars) + +# === API SECURITY === +CORS_ORIGINS=http://localhost:3000 +``` + +- [ ] **Step 3: Commit** + +```bash +git add docker-compose.yml .env.example +git commit -m "fix: move hardcoded postgres credentials to .env, annotate secrets" +``` + +--- + +### Task 7: Dockerfile Optimization + +**Files:** +- Create: `.dockerignore` +- Modify: All 7 Dockerfiles in `services/*/Dockerfile` + +- [ ] **Step 1: Create .dockerignore** + +Create `.dockerignore` at project root: + +``` +__pycache__ +*.pyc +*.pyo +.git +.github +.venv +.env +.env.* +!.env.example +tests/ +docs/ +*.md +.ruff_cache +.pytest_cache +.mypy_cache +monitoring/ +scripts/ +cli/ +``` + +- [ ] **Step 2: Update data-collector Dockerfile** + +Replace `services/data-collector/Dockerfile`: + +```dockerfile +FROM python:3.12-slim AS builder +WORKDIR /app +COPY shared/ shared/ +RUN pip install --no-cache-dir ./shared +COPY services/data-collector/ services/data-collector/ +RUN pip install --no-cache-dir ./services/data-collector + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin +ENV PYTHONPATH=/app +USER appuser +CMD ["python", "-m", "data_collector.main"] +``` + +- [ ] **Step 3: Update all other Dockerfiles with same pattern** + +Apply the same multi-stage + non-root pattern to: +- `services/strategy-engine/Dockerfile` (also copies strategies/) +- `services/order-executor/Dockerfile` +- `services/portfolio-manager/Dockerfile` +- `services/api/Dockerfile` (also copies strategies/, uses uvicorn CMD) +- `services/news-collector/Dockerfile` (also runs nltk download) +- `services/backtester/Dockerfile` (also copies strategies/) + +For **strategy-engine** Dockerfile: +```dockerfile +FROM python:3.12-slim AS builder +WORKDIR /app +COPY shared/ shared/ +RUN pip install --no-cache-dir ./shared +COPY services/strategy-engine/ services/strategy-engine/ +RUN pip install --no-cache-dir ./services/strategy-engine + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin +COPY services/strategy-engine/strategies/ /app/strategies/ +ENV PYTHONPATH=/app +USER appuser +CMD ["python", "-m", "strategy_engine.main"] +``` + +For **news-collector** Dockerfile: +```dockerfile +FROM python:3.12-slim AS builder +WORKDIR /app +COPY shared/ shared/ +RUN pip install --no-cache-dir ./shared +COPY services/news-collector/ services/news-collector/ +RUN pip install --no-cache-dir ./services/news-collector +RUN python -c "import nltk; nltk.download('vader_lexicon', download_dir='/usr/local/nltk_data')" + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin +COPY --from=builder /usr/local/nltk_data /usr/local/nltk_data +ENV PYTHONPATH=/app +USER appuser +CMD ["python", "-m", "news_collector.main"] +``` + +For **api** Dockerfile: +```dockerfile +FROM python:3.12-slim AS builder +WORKDIR /app +COPY shared/ shared/ +RUN pip install --no-cache-dir ./shared +COPY services/api/ services/api/ +RUN pip install --no-cache-dir ./services/api +COPY services/strategy-engine/ services/strategy-engine/ +RUN pip install --no-cache-dir ./services/strategy-engine + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin +COPY services/strategy-engine/strategies/ /app/strategies/ +ENV PYTHONPATH=/app STRATEGIES_DIR=/app/strategies +USER appuser +CMD ["uvicorn", "trading_api.main:app", "--host", "0.0.0.0", "--port", "8000", "--timeout-graceful-shutdown", "30"] +``` + +For **order-executor**, **portfolio-manager**, **backtester** — same pattern as data-collector, adjusting the service name and CMD. + +- [ ] **Step 4: Verify Docker build works** + +Run: `docker compose build --quiet` +Expected: All images build successfully + +- [ ] **Step 5: Commit** + +```bash +git add .dockerignore services/*/Dockerfile +git commit -m "feat: optimize Dockerfiles with multi-stage builds, non-root user, .dockerignore" +``` + +--- + +### Task 8: Database Index Migration + +**Files:** +- Create: `shared/alembic/versions/003_add_missing_indexes.py` + +- [ ] **Step 1: Create migration file** + +Create `shared/alembic/versions/003_add_missing_indexes.py`: + +```python +"""Add missing indexes for common query patterns. + +Revision ID: 003 +Revises: 002 +""" + +from alembic import op + +revision = "003" +down_revision = "002" + + +def upgrade(): + op.create_index("idx_signals_symbol_created", "signals", ["symbol", "created_at"]) + op.create_index("idx_orders_symbol_status_created", "orders", ["symbol", "status", "created_at"]) + op.create_index("idx_trades_order_id", "trades", ["order_id"]) + op.create_index("idx_trades_symbol_traded", "trades", ["symbol", "traded_at"]) + op.create_index("idx_portfolio_snapshots_at", "portfolio_snapshots", ["snapshot_at"]) + op.create_index("idx_symbol_scores_symbol", "symbol_scores", ["symbol"], unique=True) + + +def downgrade(): + op.drop_index("idx_symbol_scores_symbol", table_name="symbol_scores") + op.drop_index("idx_portfolio_snapshots_at", table_name="portfolio_snapshots") + op.drop_index("idx_trades_symbol_traded", table_name="trades") + op.drop_index("idx_trades_order_id", table_name="trades") + op.drop_index("idx_orders_symbol_status_created", table_name="orders") + op.drop_index("idx_signals_symbol_created", table_name="signals") +``` + +- [ ] **Step 2: Verify migration runs (requires infra)** + +Run: `make infra && cd shared && alembic upgrade head` +Expected: Migration 003 applied successfully + +- [ ] **Step 3: Commit** + +```bash +git add shared/alembic/versions/003_add_missing_indexes.py +git commit -m "feat: add missing DB indexes for signals, orders, trades, snapshots" +``` + +--- + +### Task 9: Docker Compose Resource Limits & Networks + +**Files:** +- Modify: `docker-compose.yml` + +- [ ] **Step 1: Add networks and resource limits** + +Add to `docker-compose.yml` at bottom: + +```yaml +networks: + internal: + driver: bridge + monitoring: + driver: bridge +``` + +Add `networks: [internal]` to all application services (redis, postgres, data-collector, strategy-engine, order-executor, portfolio-manager, api, news-collector). + +Add `networks: [internal, monitoring]` to prometheus, grafana. Add `networks: [monitoring]` to loki, promtail. + +Add to each application service: + +```yaml + deploy: + resources: + limits: + memory: 512M + cpus: '1.0' +``` + +For strategy-engine and backtester, use `memory: 1G` instead. + +- [ ] **Step 2: Verify compose config is valid** + +Run: `docker compose config --quiet` +Expected: No errors + +- [ ] **Step 3: Commit** + +```bash +git add docker-compose.yml +git commit -m "feat: add resource limits and network isolation to docker-compose" +``` + +--- + +## Phase 3: Service-Level Improvements + +### Task 10: Graceful Shutdown for All Services + +**Files:** +- Modify: `services/data-collector/src/data_collector/main.py` +- Modify: `services/strategy-engine/src/strategy_engine/main.py` +- Modify: `services/order-executor/src/order_executor/main.py` +- Modify: `services/portfolio-manager/src/portfolio_manager/main.py` +- Modify: `services/news-collector/src/news_collector/main.py` +- Modify: `services/api/src/trading_api/main.py` + +- [ ] **Step 1: Create a shared shutdown helper** + +Add to `shared/src/shared/shutdown.py`: + +```python +"""Graceful shutdown utilities for services.""" + +import asyncio +import logging +import signal + +logger = logging.getLogger(__name__) + + +class GracefulShutdown: + """Manages graceful shutdown via SIGTERM/SIGINT signals.""" + + def __init__(self) -> None: + self._event = asyncio.Event() + + @property + def is_shutting_down(self) -> bool: + return self._event.is_set() + + async def wait(self) -> None: + await self._event.wait() + + def trigger(self) -> None: + logger.info("shutdown_signal_received") + self._event.set() + + def install_handlers(self) -> None: + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, self.trigger) +``` + +- [ ] **Step 2: Add shutdown to data-collector main loop** + +In `services/data-collector/src/data_collector/main.py`, add at the start of `run()`: + +```python +from shared.shutdown import GracefulShutdown + +shutdown = GracefulShutdown() +shutdown.install_handlers() +``` + +Replace the main `while True` loop condition with `while not shutdown.is_shutting_down`. + +- [ ] **Step 3: Apply same pattern to all other services** + +For each service's `main.py`, add `GracefulShutdown` import, install handlers at start of `run()`, and replace infinite loops with `while not shutdown.is_shutting_down`. + +For strategy-engine: also cancel tasks on shutdown. +For portfolio-manager: also cancel snapshot_loop task. +For news-collector: also cancel all collector loop tasks. + +- [ ] **Step 4: Run tests** + +Run: `pytest -v` +Expected: All tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add shared/src/shared/shutdown.py services/*/src/*/main.py +git commit -m "feat: add graceful shutdown with SIGTERM/SIGINT handlers to all services" +``` + +--- + +### Task 11: Exception Handling Specialization + +**Files:** +- Modify: All service `main.py` files +- Modify: `shared/src/shared/db.py` + +- [ ] **Step 1: Specialize exceptions in data-collector/main.py** + +Replace broad `except Exception` blocks. For example, in the fetch loop: + +```python +# Before: +except Exception as exc: + log.warning("fetch_bar_failed", symbol=symbol, error=str(exc)) + +# After: +except (ConnectionError, TimeoutError, aiohttp.ClientError) as exc: + log.warning("fetch_bar_network_error", symbol=symbol, error=str(exc)) +except (ValueError, KeyError) as exc: + log.warning("fetch_bar_parse_error", symbol=symbol, error=str(exc)) +except Exception as exc: + log.error("fetch_bar_unexpected", symbol=symbol, error=str(exc), exc_info=True) +``` + +- [ ] **Step 2: Specialize exceptions in strategy-engine, order-executor, portfolio-manager, news-collector** + +Apply the same pattern: network errors → warning + retry, parse errors → warning + skip, unexpected → error + exc_info. + +- [ ] **Step 3: Specialize exceptions in db.py** + +In `shared/src/shared/db.py`, the transaction pattern can distinguish: + +```python +except (asyncpg.PostgresError, sqlalchemy.exc.OperationalError) as exc: + await session.rollback() + logger.error("db_operation_error", error=str(exc)) + raise +except Exception: + await session.rollback() + raise +``` + +- [ ] **Step 4: Run tests** + +Run: `pytest -v` +Expected: All tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add services/*/src/*/main.py shared/src/shared/db.py +git commit -m "refactor: specialize exception handling across all services" +``` + +--- + +### Task 12: Fix Stock Selector (Bug Fix + Dedup + Session Reuse) + +**Files:** +- Modify: `services/strategy-engine/src/strategy_engine/stock_selector.py` +- Modify: `services/strategy-engine/tests/test_stock_selector.py` (if exists, otherwise create) + +- [ ] **Step 1: Fix the critical bug on line 217** + +In `stock_selector.py` line 217, replace: +```python +self._session = anthropic_model +``` +with: +```python +self._model = anthropic_model +``` + +- [ ] **Step 2: Extract common JSON parsing function** + +Replace the duplicate parsing logic. Add at module level (replacing `_parse_llm_selections`): + +```python +def _extract_json_array(text: str) -> list[dict] | None: + """Extract a JSON array from text that may contain markdown code blocks.""" + code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) + if code_block: + raw = code_block.group(1) + else: + array_match = re.search(r"\[.*\]", text, re.DOTALL) + if array_match: + raw = array_match.group(0) + else: + raw = text.strip() + + try: + data = json.loads(raw) + if isinstance(data, list): + return [item for item in data if isinstance(item, dict)] + return None + except (json.JSONDecodeError, TypeError): + return None + + +def _parse_llm_selections(text: str) -> list[SelectedStock]: + """Parse LLM response into SelectedStock list.""" + items = _extract_json_array(text) + if items is None: + return [] + selections = [] + for item in items: + try: + selections.append( + SelectedStock( + symbol=item["symbol"], + side=OrderSide(item["side"]), + conviction=float(item["conviction"]), + reason=item.get("reason", ""), + key_news=item.get("key_news", []), + ) + ) + except (KeyError, ValueError) as e: + logger.warning("Skipping invalid selection item: %s", e) + return selections +``` + +Update `LLMCandidateSource._parse_candidates()` to use `_extract_json_array`: + +```python + def _parse_candidates(self, text: str) -> list[Candidate]: + items = _extract_json_array(text) + if items is None: + return [] + candidates = [] + for item in items: + try: + direction_str = item.get("direction", "BUY") + direction = OrderSide(direction_str) + except ValueError: + direction = None + candidates.append( + Candidate( + symbol=item["symbol"], + source="llm", + direction=direction, + score=float(item.get("score", 0.5)), + reason=item.get("reason", ""), + ) + ) + return candidates +``` + +- [ ] **Step 3: Add session reuse to StockSelector** + +Add `_http_session` to `StockSelector.__init__()`: + +```python +self._http_session: aiohttp.ClientSession | None = None +``` + +Add helper method: + +```python +async def _ensure_session(self) -> aiohttp.ClientSession: + if self._http_session is None or self._http_session.closed: + self._http_session = aiohttp.ClientSession() + return self._http_session + +async def close(self) -> None: + if self._http_session and not self._http_session.closed: + await self._http_session.close() +``` + +Replace `async with aiohttp.ClientSession() as session:` in both `LLMCandidateSource.get_candidates()` and `StockSelector._llm_final_select()` with session reuse. For `LLMCandidateSource`, accept an optional session parameter. For `StockSelector._llm_final_select()`, use `self._ensure_session()`. + +- [ ] **Step 4: Run tests** + +Run: `pytest services/strategy-engine/tests/ -v` +Expected: All tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add services/strategy-engine/src/strategy_engine/stock_selector.py +git commit -m "fix: fix model attr bug, deduplicate LLM parsing, reuse aiohttp sessions" +``` + +--- + +## Phase 4: API Security + +### Task 13: Bearer Token Authentication + +**Files:** +- Create: `services/api/src/trading_api/dependencies/__init__.py` +- Create: `services/api/src/trading_api/dependencies/auth.py` +- Modify: `services/api/src/trading_api/main.py` + +- [ ] **Step 1: Create auth dependency** + +Create `services/api/src/trading_api/dependencies/__init__.py` (empty file). + +Create `services/api/src/trading_api/dependencies/auth.py`: + +```python +"""Bearer token authentication dependency.""" + +import logging + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from shared.config import Settings + +logger = logging.getLogger(__name__) + +_security = HTTPBearer(auto_error=False) +_settings = Settings() + + +async def verify_token( + credentials: HTTPAuthorizationCredentials | None = Depends(_security), +) -> None: + """Verify Bearer token. Skip auth if API_AUTH_TOKEN is not configured.""" + token = _settings.api_auth_token.get_secret_value() + if not token: + return # Auth disabled in dev mode + + if credentials is None or credentials.credentials != token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) +``` + +- [ ] **Step 2: Apply auth to all API routes** + +In `services/api/src/trading_api/main.py`, add auth dependency to routers: + +```python +from trading_api.dependencies.auth import verify_token +from fastapi import Depends + +app.include_router(portfolio_router, prefix="/api/v1/portfolio", dependencies=[Depends(verify_token)]) +app.include_router(orders_router, prefix="/api/v1/orders", dependencies=[Depends(verify_token)]) +app.include_router(strategies_router, prefix="/api/v1/strategies", dependencies=[Depends(verify_token)]) +``` + +Log a warning on startup if token is empty: + +```python +@asynccontextmanager +async def lifespan(app: FastAPI): + cfg = Settings() + if not cfg.api_auth_token.get_secret_value(): + logger.warning("API_AUTH_TOKEN not set; API authentication is disabled") + # ... rest of lifespan +``` + +- [ ] **Step 3: Write tests for auth** + +Add to `services/api/tests/test_auth.py`: + +```python +"""Tests for API authentication.""" + +from unittest.mock import patch + +import pytest +from httpx import ASGITransport, AsyncClient + +from trading_api.main import app + + +class TestAuth: + @patch("trading_api.dependencies.auth._settings") + async def test_rejects_missing_token_when_configured(self, mock_settings): + from pydantic import SecretStr + mock_settings.api_auth_token = SecretStr("test-token") + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + resp = await ac.get("/api/v1/portfolio/positions") + assert resp.status_code == 401 + + @patch("trading_api.dependencies.auth._settings") + async def test_accepts_valid_token(self, mock_settings): + from pydantic import SecretStr + mock_settings.api_auth_token = SecretStr("test-token") + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + resp = await ac.get( + "/api/v1/portfolio/positions", + headers={"Authorization": "Bearer test-token"}, + ) + # May fail with 500 if DB not available, but should NOT be 401 + assert resp.status_code != 401 +``` + +- [ ] **Step 4: Run tests** + +Run: `pytest services/api/tests/test_auth.py -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add services/api/src/trading_api/dependencies/ services/api/src/trading_api/main.py services/api/tests/test_auth.py +git commit -m "feat: add Bearer token authentication to API endpoints" +``` + +--- + +### Task 14: CORS & Rate Limiting + +**Files:** +- Modify: `services/api/src/trading_api/main.py` +- Modify: `services/api/pyproject.toml` + +- [ ] **Step 1: Add slowapi dependency** + +Already done in Task 5 (`services/api/pyproject.toml` has `slowapi>=0.1.9,<1`). + +- [ ] **Step 2: Add CORS and rate limiting to main.py** + +In `services/api/src/trading_api/main.py`: + +```python +from fastapi.middleware.cors import CORSMiddleware +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded + +from shared.config import Settings + +cfg = Settings() + +limiter = Limiter(key_func=get_remote_address) +app = FastAPI(title="Trading Platform API") +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + +app.add_middleware( + CORSMiddleware, + allow_origins=cfg.cors_origins.split(","), + allow_methods=["GET", "POST"], + allow_headers=["Authorization", "Content-Type"], +) +``` + +- [ ] **Step 3: Add rate limits to order endpoints** + +In `services/api/src/trading_api/routers/orders.py`: + +```python +from slowapi import Limiter +from slowapi.util import get_remote_address + +limiter = Limiter(key_func=get_remote_address) + +@router.get("/") +@limiter.limit("60/minute") +async def get_orders(request: Request, limit: int = 50): + ... +``` + +- [ ] **Step 4: Run tests** + +Run: `pytest services/api/tests/ -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add services/api/src/trading_api/main.py services/api/src/trading_api/routers/ +git commit -m "feat: add CORS middleware and rate limiting to API" +``` + +--- + +### Task 15: API Input Validation & Response Models + +**Files:** +- Modify: `services/api/src/trading_api/routers/portfolio.py` +- Modify: `services/api/src/trading_api/routers/orders.py` + +- [ ] **Step 1: Add Query validation to portfolio.py** + +```python +from fastapi import Query + +@router.get("/snapshots") +async def get_snapshots(request: Request, days: int = Query(30, ge=1, le=365)): + ... +``` + +- [ ] **Step 2: Add Query validation to orders.py** + +```python +from fastapi import Query + +@router.get("/") +async def get_orders(request: Request, limit: int = Query(50, ge=1, le=1000)): + ... + +@router.get("/signals") +async def get_signals(request: Request, limit: int = Query(50, ge=1, le=1000)): + ... +``` + +- [ ] **Step 3: Run tests** + +Run: `pytest services/api/tests/ -v` +Expected: PASS + +- [ ] **Step 4: Commit** + +```bash +git add services/api/src/trading_api/routers/ +git commit -m "feat: add input validation with Query bounds to API endpoints" +``` + +--- + +## Phase 5: Operational Maturity + +### Task 16: Enhanced Ruff Configuration + +**Files:** +- Modify: `pyproject.toml:12-14` + +- [ ] **Step 1: Update ruff config in pyproject.toml** + +Replace the ruff section in root `pyproject.toml`: + +```toml +[tool.ruff] +target-version = "py312" +line-length = 100 + +[tool.ruff.lint] +select = ["E", "W", "F", "I", "B", "UP", "ASYNC", "PERF", "C4", "RUF"] +ignore = ["E501"] + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["F841"] +"*/tests/*" = ["F841"] + +[tool.ruff.lint.isort] +known-first-party = ["shared"] +``` + +- [ ] **Step 2: Auto-fix existing violations** + +Run: `ruff check --fix . && ruff format .` +Expected: Fixes applied + +- [ ] **Step 3: Verify no remaining errors** + +Run: `ruff check . && ruff format --check .` +Expected: No errors + +- [ ] **Step 4: Run tests to verify no regressions** + +Run: `pytest -v` +Expected: All tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add pyproject.toml +git commit -m "chore: enhance ruff lint rules with ASYNC, bugbear, isort, pyupgrade" +``` + +Then commit auto-fixes separately: + +```bash +git add -A +git commit -m "style: auto-fix lint violations from enhanced ruff rules" +``` + +--- + +### Task 17: GitHub Actions CI Pipeline + +**Files:** +- Create: `.github/workflows/ci.yml` + +- [ ] **Step 1: Create CI workflow** + +Create `.github/workflows/ci.yml`: + +```yaml +name: CI + +on: + push: + branches: [master] + pull_request: + branches: [master] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: pip install ruff + - run: ruff check . + - run: ruff format --check . + + test: + runs-on: ubuntu-latest + services: + redis: + image: redis:7-alpine + ports: [6379:6379] + options: >- + --health-cmd "redis-cli ping" + --health-interval 5s + --health-timeout 3s + --health-retries 5 + postgres: + image: postgres:16-alpine + env: + POSTGRES_USER: trading + POSTGRES_PASSWORD: trading + POSTGRES_DB: trading + ports: [5432:5432] + options: >- + --health-cmd pg_isready + --health-interval 5s + --health-timeout 3s + --health-retries 5 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: | + pip install -e shared/[dev] + pip install -e services/strategy-engine/[dev] + pip install -e services/data-collector/[dev] + pip install -e services/order-executor/[dev] + pip install -e services/portfolio-manager/[dev] + pip install -e services/news-collector/[dev] + pip install -e services/api/[dev] + pip install -e services/backtester/[dev] + pip install pytest-cov + - run: pytest -v --cov=shared/src --cov=services --cov-report=xml --cov-report=term-missing + env: + DATABASE_URL: postgresql+asyncpg://trading:trading@localhost:5432/trading + REDIS_URL: redis://localhost:6379 + - uses: actions/upload-artifact@v4 + with: + name: coverage-report + path: coverage.xml + + docker: + runs-on: ubuntu-latest + needs: [lint, test] + if: github.ref == 'refs/heads/master' + steps: + - uses: actions/checkout@v4 + - run: docker compose build --quiet +``` + +- [ ] **Step 2: Commit** + +```bash +mkdir -p .github/workflows +git add .github/workflows/ci.yml +git commit -m "feat: add GitHub Actions CI pipeline with lint, test, docker build" +``` + +--- + +### Task 18: Prometheus Alerting Rules + +**Files:** +- Create: `monitoring/prometheus/alert_rules.yml` +- Modify: `monitoring/prometheus.yml` + +- [ ] **Step 1: Create alert rules** + +Create `monitoring/prometheus/alert_rules.yml`: + +```yaml +groups: + - name: trading-platform + rules: + - alert: ServiceDown + expr: up == 0 + for: 1m + labels: + severity: critical + annotations: + summary: "Service {{ $labels.job }} is down" + description: "{{ $labels.instance }} has been unreachable for 1 minute." + + - alert: HighErrorRate + expr: rate(errors_total[5m]) > 10 + for: 2m + labels: + severity: warning + annotations: + summary: "High error rate on {{ $labels.job }}" + description: "Error rate is {{ $value }} errors/sec over 5 minutes." + + - alert: HighProcessingLatency + expr: histogram_quantile(0.95, rate(processing_seconds_bucket[5m])) > 5 + for: 5m + labels: + severity: warning + annotations: + summary: "High p95 latency on {{ $labels.job }}" + description: "95th percentile processing time is {{ $value }}s." +``` + +- [ ] **Step 2: Reference alert rules in prometheus.yml** + +In `monitoring/prometheus.yml`, add after `global:`: + +```yaml +rule_files: + - "/etc/prometheus/alert_rules.yml" +``` + +Update `docker-compose.yml` prometheus service to mount the file: + +```yaml + prometheus: + volumes: + - ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml + - ./monitoring/prometheus/alert_rules.yml:/etc/prometheus/alert_rules.yml +``` + +- [ ] **Step 3: Commit** + +```bash +git add monitoring/prometheus/alert_rules.yml monitoring/prometheus.yml docker-compose.yml +git commit -m "feat: add Prometheus alerting rules for service health, errors, latency" +``` + +--- + +### Task 19: Code Coverage Configuration + +**Files:** +- Modify: `pyproject.toml` + +- [ ] **Step 1: Add pytest-cov config** + +Add to `pyproject.toml`: + +```toml +[tool.coverage.run] +branch = true +source = ["shared/src", "services"] +omit = ["*/tests/*", "*/alembic/*"] + +[tool.coverage.report] +fail_under = 60 +show_missing = true +exclude_lines = [ + "pragma: no cover", + "if __name__", + "if TYPE_CHECKING", +] +``` + +Update pytest addopts: +```toml +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["shared/tests", "services", "cli/tests", "tests"] +addopts = "--import-mode=importlib" +``` + +Note: `--cov` flags are passed explicitly in CI, not in addopts (to avoid slowing local dev). + +- [ ] **Step 2: Verify coverage works** + +Run: `pip install pytest-cov && pytest --cov=shared/src --cov-report=term-missing` +Expected: Coverage report printed, no errors + +- [ ] **Step 3: Commit** + +```bash +git add pyproject.toml +git commit -m "chore: add pytest-cov configuration with 60% minimum coverage threshold" +``` + +--- + +## Summary + +| Phase | Tasks | Estimated Commits | +|-------|-------|-------------------| +| 1: Shared Library | Tasks 1-5 | 5 commits | +| 2: Infrastructure | Tasks 6-9 | 4 commits | +| 3: Service Fixes | Tasks 10-12 | 3 commits | +| 4: API Security | Tasks 13-15 | 3 commits | +| 5: Operations | Tasks 16-19 | 5 commits | +| **Total** | **19 tasks** | **~20 commits** | diff --git a/docs/superpowers/specs/2026-04-02-platform-upgrade-design.md b/docs/superpowers/specs/2026-04-02-platform-upgrade-design.md new file mode 100644 index 0000000..9c84e10 --- /dev/null +++ b/docs/superpowers/specs/2026-04-02-platform-upgrade-design.md @@ -0,0 +1,257 @@ +# Platform Upgrade Design Spec + +**Date**: 2026-04-02 +**Approach**: Bottom-Up (shared library → infra → services → API security → operations) + +--- + +## Phase 1: Shared Library Hardening + +### 1-1. Resilience Module (`shared/src/shared/resilience.py`) +Currently empty. Implement: +- **`retry_async()`** — tenacity-based exponential backoff + jitter decorator. Configurable max retries (default 3), base delay (1s), max delay (30s). +- **`CircuitBreaker`** — Tracks consecutive failures. Opens after N failures (default 5), stays open for configurable cooldown (default 60s), transitions to half-open to test recovery. +- **`timeout()`** — asyncio-based timeout wrapper. Raises `TimeoutError` after configurable duration. +- All decorators composable: `@retry_async() @circuit_breaker() async def call_api():` + +### 1-2. DB Connection Pooling (`shared/src/shared/db.py`) +Add to `create_async_engine()`: +- `pool_size=20` (configurable via `DB_POOL_SIZE`) +- `max_overflow=10` (configurable via `DB_MAX_OVERFLOW`) +- `pool_pre_ping=True` (verify connections before use) +- `pool_recycle=3600` (recycle stale connections) +Add corresponding fields to `Settings`. + +### 1-3. Redis Resilience (`shared/src/shared/broker.py`) +- Add to `redis.asyncio.from_url()`: `socket_keepalive=True`, `health_check_interval=30`, `retry_on_timeout=True` +- Wrap `publish()`, `read_group()`, `ensure_group()` with `@retry_async()` from resilience module +- Add `reconnect()` method for connection loss recovery + +### 1-4. Config Validation (`shared/src/shared/config.py`) +- Add `field_validator` for business logic: `risk_max_position_size > 0`, `health_port` in 1024-65535, `log_level` in valid set +- Change secret fields to `SecretStr`: `alpaca_api_key`, `alpaca_api_secret`, `database_url`, `redis_url`, `telegram_bot_token`, `anthropic_api_key`, `finnhub_api_key` +- Update all consumers to call `.get_secret_value()` where needed + +### 1-5. Dependency Pinning +All `pyproject.toml` files: add upper bounds. +Examples: +- `pydantic>=2.8,<3` +- `redis>=5.0,<6` +- `sqlalchemy[asyncio]>=2.0,<3` +- `numpy>=1.26,<3` +- `pandas>=2.1,<3` +- `anthropic>=0.40,<1` +Run `uv lock` to generate lock file. + +--- + +## Phase 2: Infrastructure Hardening + +### 2-1. Docker Secrets & Environment +- Remove hardcoded `POSTGRES_USER: trading` / `POSTGRES_PASSWORD: trading` from `docker-compose.yml` +- Reference via `${POSTGRES_USER}` / `${POSTGRES_PASSWORD}` from `.env` +- Add comments in `.env.example` marking secret vs config variables + +### 2-2. Dockerfile Optimization (all 7 services) +Pattern for each Dockerfile: +```dockerfile +# Stage 1: builder +FROM python:3.12-slim AS builder +WORKDIR /app +COPY shared/pyproject.toml shared/setup.cfg shared/ +COPY shared/src/ shared/src/ +RUN pip install --no-cache-dir ./shared +COPY services/<name>/pyproject.toml services/<name>/ +COPY services/<name>/src/ services/<name>/src/ +RUN pip install --no-cache-dir ./services/<name> + +# Stage 2: runtime +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin +USER appuser +CMD ["python", "-m", "<module>.main"] +``` + +Create root `.dockerignore`: +``` +__pycache__ +*.pyc +.git +.venv +.env +tests/ +docs/ +*.md +.ruff_cache +``` + +### 2-3. Database Index Migration (`003_add_missing_indexes.py`) +New Alembic migration adding: +- `idx_signals_symbol_created` on `signals(symbol, created_at)` +- `idx_orders_symbol_status_created` on `orders(symbol, status, created_at)` +- `idx_trades_order_id` on `trades(order_id)` +- `idx_trades_symbol_traded` on `trades(symbol, traded_at)` +- `idx_portfolio_snapshots_at` on `portfolio_snapshots(snapshot_at)` +- `idx_symbol_scores_symbol` unique on `symbol_scores(symbol)` + +### 2-4. Docker Compose Resource Limits +Add to each service: +```yaml +deploy: + resources: + limits: + memory: 512M + cpus: '1.0' +``` +Strategy-engine and backtester get `memory: 1G` (pandas/numpy usage). + +Add explicit networks: +```yaml +networks: + internal: + driver: bridge + monitoring: + driver: bridge +``` + +--- + +## Phase 3: Service-Level Improvements + +### 3-1. Graceful Shutdown (all services) +Add to each service's `main()`: +```python +shutdown_event = asyncio.Event() + +def _signal_handler(): + log.info("shutdown_signal_received") + shutdown_event.set() + +loop = asyncio.get_event_loop() +loop.add_signal_handler(signal.SIGTERM, _signal_handler) +loop.add_signal_handler(signal.SIGINT, _signal_handler) +``` +Main loops check `shutdown_event.is_set()` to exit gracefully. +API service: add `--timeout-graceful-shutdown 30` to uvicorn CMD. + +### 3-2. Exception Specialization (all services) +Replace broad `except Exception` with layered handling: +- `ConnectionError`, `TimeoutError` → retry via resilience module +- `ValueError`, `KeyError` → log warning, skip item, continue +- `Exception` → top-level only, `exc_info=True` for full traceback + Telegram alert + +Target: reduce 63 broad catches to ~10 top-level safety nets. + +### 3-3. LLM Parsing Deduplication (`stock_selector.py`) +Extract `_extract_json_from_text(text: str) -> list | dict | None`: +- Tries ```` ```json ``` ```` code block extraction +- Falls back to `re.search(r"\[.*\]", text, re.DOTALL)` +- Falls back to raw `json.loads(text.strip())` +Replace 3 duplicate parsing blocks with single call. + +### 3-4. aiohttp Session Reuse (`stock_selector.py`) +- Add `_session: aiohttp.ClientSession | None = None` to `StockSelector` +- Lazy-init in `_ensure_session()`, close in `close()` +- Replace all `async with aiohttp.ClientSession()` with `self._session` + +--- + +## Phase 4: API Security + +### 4-1. Bearer Token Authentication +- Add `api_auth_token: SecretStr = ""` to `Settings` +- Create `dependencies/auth.py` with `verify_token()` dependency +- Apply to all `/api/v1/*` routes via router-level `dependencies=[Depends(verify_token)]` +- If token is empty string → skip auth (dev mode), log warning on startup + +### 4-2. CORS Configuration +```python +app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins.split(","), # default: "http://localhost:3000" + allow_methods=["GET", "POST"], + allow_headers=["Authorization", "Content-Type"], +) +``` +Add `cors_origins: str = "http://localhost:3000"` to Settings. + +### 4-3. Rate Limiting +- Add `slowapi` dependency +- Global default: 60 req/min per IP +- Order-related endpoints: 10 req/min per IP +- Return `429 Too Many Requests` with `Retry-After` header + +### 4-4. Input Validation +- All `limit` params: `Query(default=50, ge=1, le=1000)` +- All `days` params: `Query(default=30, ge=1, le=365)` +- Add Pydantic `response_model` to all endpoints (enables auto OpenAPI docs) +- Add `symbol` param validation: uppercase, 1-5 chars, alphanumeric + +--- + +## Phase 5: Operational Maturity + +### 5-1. GitHub Actions CI/CD +File: `.github/workflows/ci.yml` + +**PR trigger** (`pull_request`): +1. Install deps (`uv sync`) +2. Ruff lint + format check +3. pytest with coverage (`--cov --cov-report=xml`) +4. Upload coverage to PR comment + +**Main push** (`push: branches: [master]`): +1. Same lint + test +2. `docker compose build` +3. (Future: push to registry) + +### 5-2. Ruff Rules Enhancement +```toml +[tool.ruff.lint] +select = ["E", "W", "F", "I", "B", "UP", "ASYNC", "PERF", "C4", "RUF"] +ignore = ["E501"] + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["F841"] + +[tool.ruff.lint.isort] +known-first-party = ["shared"] +``` +Run `ruff check --fix .` and `ruff format .` to fix existing violations, commit separately. + +### 5-3. Prometheus Alerting +File: `monitoring/prometheus/alert_rules.yml` +Rules: +- `ServiceDown`: `service_up == 0` for 1 min → critical +- `HighErrorRate`: `rate(errors_total[5m]) > 10` → warning +- `HighLatency`: `histogram_quantile(0.95, processing_seconds) > 5` → warning + +Add Alertmanager config with Telegram webhook (reuse existing bot token). +Reference alert rules in `monitoring/prometheus.yml`. + +### 5-4. Code Coverage +Add to root `pyproject.toml`: +```toml +[tool.pytest.ini_options] +addopts = "--cov=shared/src --cov=services --cov-report=term-missing" + +[tool.coverage.run] +branch = true +omit = ["tests/*", "*/alembic/*"] + +[tool.coverage.report] +fail_under = 70 +``` +Add `pytest-cov` to dev dependencies. + +--- + +## Out of Scope +- Kubernetes/Helm charts (premature — Docker Compose sufficient for current scale) +- External secrets manager (Vault, AWS SM — overkill for single-machine deployment) +- OpenTelemetry distributed tracing (add when debugging cross-service issues) +- API versioning beyond `/api/v1/` prefix +- Data retention/partitioning (address when data volume becomes an issue) diff --git a/monitoring/prometheus.yml b/monitoring/prometheus.yml index b6dc853..e177c9c 100644 --- a/monitoring/prometheus.yml +++ b/monitoring/prometheus.yml @@ -1,5 +1,7 @@ global: scrape_interval: 15s +rule_files: + - "/etc/prometheus/alert_rules.yml" scrape_configs: - job_name: "trading-services" authorization: diff --git a/monitoring/prometheus/alert_rules.yml b/monitoring/prometheus/alert_rules.yml new file mode 100644 index 0000000..aca2f1c --- /dev/null +++ b/monitoring/prometheus/alert_rules.yml @@ -0,0 +1,29 @@ +groups: + - name: trading-platform + rules: + - alert: ServiceDown + expr: up == 0 + for: 1m + labels: + severity: critical + annotations: + summary: "Service {{ $labels.job }} is down" + description: "{{ $labels.instance }} has been unreachable for 1 minute." + + - alert: HighErrorRate + expr: rate(errors_total[5m]) > 10 + for: 2m + labels: + severity: warning + annotations: + summary: "High error rate on {{ $labels.job }}" + description: "Error rate is {{ $value }} errors/sec over 5 minutes." + + - alert: HighProcessingLatency + expr: histogram_quantile(0.95, rate(processing_seconds_bucket[5m])) > 5 + for: 5m + labels: + severity: warning + annotations: + summary: "High p95 latency on {{ $labels.job }}" + description: "95th percentile processing time is {{ $value }}s." diff --git a/pyproject.toml b/pyproject.toml index 6938778..08150fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,3 +12,28 @@ addopts = "--import-mode=importlib" [tool.ruff] target-version = "py312" line-length = 100 + +[tool.ruff.lint] +select = ["E", "W", "F", "I", "B", "UP", "ASYNC", "PERF", "C4", "RUF"] +ignore = ["E501", "RUF012", "B008", "ASYNC240"] + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["F841"] +"*/tests/*" = ["F841"] + +[tool.ruff.lint.isort] +known-first-party = ["shared"] + +[tool.coverage.run] +branch = true +source = ["shared/src", "services"] +omit = ["*/tests/*", "*/alembic/*"] + +[tool.coverage.report] +fail_under = 60 +show_missing = true +exclude_lines = [ + "pragma: no cover", + "if __name__", + "if TYPE_CHECKING", +] diff --git a/scripts/backtest_moc.py b/scripts/backtest_moc.py index 92b426b..9307668 100755 --- a/scripts/backtest_moc.py +++ b/scripts/backtest_moc.py @@ -4,11 +4,11 @@ Usage: python scripts/backtest_moc.py """ -import sys import random -from pathlib import Path +import sys +from datetime import UTC, datetime, timedelta from decimal import Decimal -from datetime import datetime, timedelta, timezone +from pathlib import Path ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(ROOT / "services" / "strategy-engine" / "src")) @@ -16,10 +16,11 @@ sys.path.insert(0, str(ROOT / "services" / "strategy-engine")) sys.path.insert(0, str(ROOT / "services" / "backtester" / "src")) sys.path.insert(0, str(ROOT / "shared" / "src")) -from shared.models import Candle # noqa: E402 from backtester.engine import BacktestEngine # noqa: E402 from strategies.moc_strategy import MocStrategy # noqa: E402 +from shared.models import Candle # noqa: E402 + def generate_stock_candles( symbol: str = "AAPL", @@ -38,7 +39,7 @@ def generate_stock_candles( """ candles = [] price = base_price - start_date = datetime(2025, 1, 2, tzinfo=timezone.utc) # Start on a Thursday + start_date = datetime(2025, 1, 2, tzinfo=UTC) # Start on a Thursday trading_day = 0 current_date = start_date @@ -136,27 +137,26 @@ def main(): ] # Parameter grid - param_sets = [] - for rsi_min in [25, 30, 35]: - for rsi_max in [55, 60, 65]: - for sl in [1.5, 2.0, 3.0]: - for ema in [10, 20]: - param_sets.append( - { - "quantity_pct": 0.2, - "stop_loss_pct": sl, - "rsi_min": rsi_min, - "rsi_max": rsi_max, - "ema_period": ema, - "volume_avg_period": 20, - "min_volume_ratio": 0.8, - "buy_start_utc": 19, - "buy_end_utc": 21, - "sell_start_utc": 14, - "sell_end_utc": 15, - "max_positions": 5, - } - ) + param_sets = [ + { + "quantity_pct": 0.2, + "stop_loss_pct": sl, + "rsi_min": rsi_min, + "rsi_max": rsi_max, + "ema_period": ema, + "volume_avg_period": 20, + "min_volume_ratio": 0.8, + "buy_start_utc": 19, + "buy_end_utc": 21, + "sell_start_utc": 14, + "sell_end_utc": 15, + "max_positions": 5, + } + for rsi_min in [25, 30, 35] + for rsi_max in [55, 60, 65] + for sl in [1.5, 2.0, 3.0] + for ema in [10, 20] + ] print(f"\nParameter combinations: {len(param_sets)}") print(f"Stocks: {[s[0] for s in stocks]}") @@ -234,7 +234,7 @@ def main(): print("\n" + "=" * 60) print("WORST 3 PARAMETER SETS") print("=" * 60) - for _rank, (params, profit, trades, sharpe, _) in enumerate(param_results[-3:], 1): + for _rank, (params, profit, _trades, sharpe, _) in enumerate(param_results[-3:], 1): print( f" RSI({params['rsi_min']}-{params['rsi_max']})," f" SL={params['stop_loss_pct']}%, EMA={params['ema_period']}" diff --git a/scripts/stock_screener.py b/scripts/stock_screener.py index 7a5c0ba..7552aa3 100755 --- a/scripts/stock_screener.py +++ b/scripts/stock_screener.py @@ -16,7 +16,7 @@ import asyncio import json import os import sys -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path ROOT = Path(__file__).resolve().parents[1] @@ -195,7 +195,7 @@ async def main_async(top_n: int = 5, universe: list[str] | None = None): print("=" * 60) print("Daily Stock Screener — MOC Strategy") - print(f"Date: {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M UTC')}") + print(f"Date: {datetime.now(UTC).strftime('%Y-%m-%d %H:%M UTC')}") print(f"Universe: {len(symbols)} stocks") print("=" * 60) diff --git a/services/api/Dockerfile b/services/api/Dockerfile index b942075..93d2b75 100644 --- a/services/api/Dockerfile +++ b/services/api/Dockerfile @@ -1,11 +1,18 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/api/ services/api/ RUN pip install --no-cache-dir ./services/api -COPY services/strategy-engine/strategies/ /app/strategies/ COPY services/strategy-engine/ services/strategy-engine/ RUN pip install --no-cache-dir ./services/strategy-engine -ENV PYTHONPATH=/app -CMD ["uvicorn", "trading_api.main:app", "--host", "0.0.0.0", "--port", "8000"] + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin +COPY services/strategy-engine/strategies/ /app/strategies/ +ENV PYTHONPATH=/app STRATEGIES_DIR=/app/strategies +USER appuser +CMD ["uvicorn", "trading_api.main:app", "--host", "0.0.0.0", "--port", "8000", "--timeout-graceful-shutdown", "30"] diff --git a/services/api/pyproject.toml b/services/api/pyproject.toml index fd2598d..95099d2 100644 --- a/services/api/pyproject.toml +++ b/services/api/pyproject.toml @@ -3,11 +3,7 @@ name = "trading-api" version = "0.1.0" description = "REST API for the trading platform" requires-python = ">=3.12" -dependencies = [ - "fastapi>=0.110", - "uvicorn>=0.27", - "trading-shared", -] +dependencies = ["fastapi>=0.110,<1", "uvicorn>=0.27,<1", "slowapi>=0.1.9,<1", "trading-shared"] [project.optional-dependencies] dev = ["pytest>=8.0", "pytest-asyncio>=0.23", "httpx>=0.27"] diff --git a/services/api/src/trading_api/dependencies/__init__.py b/services/api/src/trading_api/dependencies/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/api/src/trading_api/dependencies/__init__.py diff --git a/services/api/src/trading_api/dependencies/auth.py b/services/api/src/trading_api/dependencies/auth.py new file mode 100644 index 0000000..a5e76c1 --- /dev/null +++ b/services/api/src/trading_api/dependencies/auth.py @@ -0,0 +1,29 @@ +"""Bearer token authentication dependency.""" + +import logging + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from shared.config import Settings + +logger = logging.getLogger(__name__) + +_security = HTTPBearer(auto_error=False) +_settings = Settings() + + +async def verify_token( + credentials: HTTPAuthorizationCredentials | None = Depends(_security), +) -> None: + """Verify Bearer token. Skip auth if API_AUTH_TOKEN is not configured.""" + token = _settings.api_auth_token.get_secret_value() + if not token: + return # Auth disabled in dev mode + + if credentials is None or credentials.credentials != token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) diff --git a/services/api/src/trading_api/main.py b/services/api/src/trading_api/main.py index 39f7b43..05c6d2f 100644 --- a/services/api/src/trading_api/main.py +++ b/services/api/src/trading_api/main.py @@ -1,33 +1,71 @@ """Trading Platform REST API.""" +import logging from contextlib import asynccontextmanager -from fastapi import FastAPI +from fastapi import Depends, FastAPI +from fastapi.middleware.cors import CORSMiddleware +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded +from slowapi.util import get_remote_address from shared.config import Settings from shared.db import Database +from trading_api.dependencies.auth import verify_token +from trading_api.routers import orders, portfolio, strategies -from trading_api.routers import portfolio, orders, strategies +logger = logging.getLogger(__name__) @asynccontextmanager async def lifespan(app: FastAPI): settings = Settings() - app.state.db = Database(settings.database_url) + if not settings.api_auth_token.get_secret_value(): + logger.warning("API_AUTH_TOKEN not set — authentication is disabled") + app.state.db = Database(settings.database_url.get_secret_value()) await app.state.db.connect() yield await app.state.db.close() +cfg = Settings() + +limiter = Limiter(key_func=get_remote_address) + app = FastAPI( title="Trading Platform API", version="0.1.0", lifespan=lifespan, ) -app.include_router(portfolio.router, prefix="/api/v1/portfolio", tags=["portfolio"]) -app.include_router(orders.router, prefix="/api/v1/orders", tags=["orders"]) -app.include_router(strategies.router, prefix="/api/v1/strategies", tags=["strategies"]) +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + +app.add_middleware( + CORSMiddleware, + allow_origins=cfg.cors_origins.split(","), + allow_methods=["GET", "POST"], + allow_headers=["Authorization", "Content-Type"], +) + +app.include_router( + portfolio.router, + prefix="/api/v1/portfolio", + tags=["portfolio"], + dependencies=[Depends(verify_token)], +) +app.include_router( + orders.router, + prefix="/api/v1/orders", + tags=["orders"], + dependencies=[Depends(verify_token)], +) +app.include_router( + strategies.router, + prefix="/api/v1/strategies", + tags=["strategies"], + dependencies=[Depends(verify_token)], +) @app.get("/health") diff --git a/services/api/src/trading_api/routers/orders.py b/services/api/src/trading_api/routers/orders.py index c69dc10..b664e2a 100644 --- a/services/api/src/trading_api/routers/orders.py +++ b/services/api/src/trading_api/routers/orders.py @@ -2,17 +2,23 @@ import logging -from fastapi import APIRouter, HTTPException, Request -from shared.sa_models import OrderRow, SignalRow +from fastapi import APIRouter, HTTPException, Query, Request +from slowapi import Limiter +from slowapi.util import get_remote_address from sqlalchemy import select +from sqlalchemy.exc import OperationalError + +from shared.sa_models import OrderRow, SignalRow logger = logging.getLogger(__name__) router = APIRouter() +limiter = Limiter(key_func=get_remote_address) @router.get("/") -async def get_orders(request: Request, limit: int = 50): +@limiter.limit("60/minute") +async def get_orders(request: Request, limit: int = Query(50, ge=1, le=1000)): """Get recent orders.""" try: db = request.app.state.db @@ -35,13 +41,17 @@ async def get_orders(request: Request, limit: int = 50): } for r in rows ] + except OperationalError as exc: + logger.error("Database error fetching orders: %s", exc) + raise HTTPException(status_code=503, detail="Database unavailable") from exc except Exception as exc: - logger.error("Failed to get orders: %s", exc) - raise HTTPException(status_code=500, detail="Failed to retrieve orders") + logger.error("Failed to get orders: %s", exc, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to retrieve orders") from exc @router.get("/signals") -async def get_signals(request: Request, limit: int = 50): +@limiter.limit("60/minute") +async def get_signals(request: Request, limit: int = Query(50, ge=1, le=1000)): """Get recent signals.""" try: db = request.app.state.db @@ -62,6 +72,9 @@ async def get_signals(request: Request, limit: int = 50): } for r in rows ] + except OperationalError as exc: + logger.error("Database error fetching signals: %s", exc) + raise HTTPException(status_code=503, detail="Database unavailable") from exc except Exception as exc: - logger.error("Failed to get signals: %s", exc) - raise HTTPException(status_code=500, detail="Failed to retrieve signals") + logger.error("Failed to get signals: %s", exc, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to retrieve signals") from exc diff --git a/services/api/src/trading_api/routers/portfolio.py b/services/api/src/trading_api/routers/portfolio.py index d76d85d..56bee7c 100644 --- a/services/api/src/trading_api/routers/portfolio.py +++ b/services/api/src/trading_api/routers/portfolio.py @@ -2,9 +2,11 @@ import logging -from fastapi import APIRouter, HTTPException, Request -from shared.sa_models import PositionRow +from fastapi import APIRouter, HTTPException, Query, Request from sqlalchemy import select +from sqlalchemy.exc import OperationalError + +from shared.sa_models import PositionRow logger = logging.getLogger(__name__) @@ -29,13 +31,16 @@ async def get_positions(request: Request): } for r in rows ] + except OperationalError as exc: + logger.error("Database error fetching positions: %s", exc) + raise HTTPException(status_code=503, detail="Database unavailable") from exc except Exception as exc: - logger.error("Failed to get positions: %s", exc) - raise HTTPException(status_code=500, detail="Failed to retrieve positions") + logger.error("Failed to get positions: %s", exc, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to retrieve positions") from exc @router.get("/snapshots") -async def get_snapshots(request: Request, days: int = 30): +async def get_snapshots(request: Request, days: int = Query(30, ge=1, le=365)): """Get portfolio snapshots for the last N days.""" try: db = request.app.state.db @@ -49,6 +54,9 @@ async def get_snapshots(request: Request, days: int = 30): } for s in snapshots ] + except OperationalError as exc: + logger.error("Database error fetching snapshots: %s", exc) + raise HTTPException(status_code=503, detail="Database unavailable") from exc except Exception as exc: - logger.error("Failed to get snapshots: %s", exc) - raise HTTPException(status_code=500, detail="Failed to retrieve snapshots") + logger.error("Failed to get snapshots: %s", exc, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to retrieve snapshots") from exc diff --git a/services/api/src/trading_api/routers/strategies.py b/services/api/src/trading_api/routers/strategies.py index 7ddd54e..157094c 100644 --- a/services/api/src/trading_api/routers/strategies.py +++ b/services/api/src/trading_api/routers/strategies.py @@ -42,6 +42,9 @@ async def list_strategies(): } for s in strategies ] + except (ImportError, FileNotFoundError) as exc: + logger.error("Strategy loading error: %s", exc) + raise HTTPException(status_code=503, detail="Strategy engine unavailable") from exc except Exception as exc: - logger.error("Failed to list strategies: %s", exc) - raise HTTPException(status_code=500, detail="Failed to list strategies") + logger.error("Failed to list strategies: %s", exc, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to list strategies") from exc diff --git a/services/api/tests/test_api.py b/services/api/tests/test_api.py index 669143b..f3b0a47 100644 --- a/services/api/tests/test_api.py +++ b/services/api/tests/test_api.py @@ -1,6 +1,7 @@ """Tests for the REST API.""" from unittest.mock import AsyncMock, patch + from fastapi.testclient import TestClient diff --git a/services/api/tests/test_orders_router.py b/services/api/tests/test_orders_router.py index 0658619..52252c5 100644 --- a/services/api/tests/test_orders_router.py +++ b/services/api/tests/test_orders_router.py @@ -1,10 +1,10 @@ """Tests for orders API router.""" -import pytest from unittest.mock import AsyncMock, MagicMock -from fastapi.testclient import TestClient -from fastapi import FastAPI +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient from trading_api.routers.orders import router diff --git a/services/api/tests/test_portfolio_router.py b/services/api/tests/test_portfolio_router.py index 3bd1b2c..8cd8ff8 100644 --- a/services/api/tests/test_portfolio_router.py +++ b/services/api/tests/test_portfolio_router.py @@ -1,11 +1,11 @@ """Tests for portfolio API router.""" -import pytest from decimal import Decimal from unittest.mock import AsyncMock, MagicMock -from fastapi.testclient import TestClient -from fastapi import FastAPI +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient from trading_api.routers.portfolio import router diff --git a/services/backtester/Dockerfile b/services/backtester/Dockerfile index 9a4f439..1108e42 100644 --- a/services/backtester/Dockerfile +++ b/services/backtester/Dockerfile @@ -1,10 +1,17 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/backtester/ services/backtester/ RUN pip install --no-cache-dir ./services/backtester + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin COPY services/strategy-engine/strategies/ /app/strategies/ ENV STRATEGIES_DIR=/app/strategies ENV PYTHONPATH=/app +USER appuser CMD ["python", "-m", "backtester.main"] diff --git a/services/backtester/pyproject.toml b/services/backtester/pyproject.toml index 2601d04..034bcf6 100644 --- a/services/backtester/pyproject.toml +++ b/services/backtester/pyproject.toml @@ -3,7 +3,7 @@ name = "backtester" version = "0.1.0" description = "Strategy backtesting engine" requires-python = ">=3.12" -dependencies = ["pandas>=2.0", "numpy>=1.20", "rich>=13.0", "trading-shared"] +dependencies = ["pandas>=2.1,<3", "numpy>=1.26,<3", "rich>=13.0,<14", "trading-shared"] [project.optional-dependencies] dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] diff --git a/services/backtester/src/backtester/engine.py b/services/backtester/src/backtester/engine.py index b03715d..fcf48f1 100644 --- a/services/backtester/src/backtester/engine.py +++ b/services/backtester/src/backtester/engine.py @@ -6,10 +6,9 @@ from dataclasses import dataclass, field from decimal import Decimal from typing import Protocol -from shared.models import Candle, Signal - from backtester.metrics import DetailedMetrics, TradeRecord, compute_detailed_metrics from backtester.simulator import OrderSimulator, SimulatedTrade +from shared.models import Candle, Signal class StrategyProtocol(Protocol): @@ -101,7 +100,7 @@ class BacktestEngine: final_balance = simulator.balance if candles: last_price = candles[-1].close - for symbol, qty in simulator.positions.items(): + for qty in simulator.positions.values(): if qty > Decimal("0"): final_balance += qty * last_price elif qty < Decimal("0"): diff --git a/services/backtester/src/backtester/main.py b/services/backtester/src/backtester/main.py index a4cea76..dbde00b 100644 --- a/services/backtester/src/backtester/main.py +++ b/services/backtester/src/backtester/main.py @@ -17,11 +17,11 @@ _STRATEGIES_DIR = Path( if _STRATEGIES_DIR.parent not in [Path(p) for p in sys.path]: sys.path.insert(0, str(_STRATEGIES_DIR.parent)) -from shared.db import Database # noqa: E402 -from shared.models import Candle # noqa: E402 from backtester.config import BacktestConfig # noqa: E402 from backtester.engine import BacktestEngine # noqa: E402 from backtester.reporter import format_report # noqa: E402 +from shared.db import Database # noqa: E402 +from shared.models import Candle # noqa: E402 async def run_backtest() -> str: @@ -45,7 +45,7 @@ async def run_backtest() -> str: except Exception as exc: raise RuntimeError(f"Failed to load strategy '{config.strategy_name}': {exc}") from exc - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() try: rows = await db.get_candles(config.symbol, config.timeframe, config.candle_limit) diff --git a/services/backtester/src/backtester/metrics.py b/services/backtester/src/backtester/metrics.py index 239cb6f..c7b032b 100644 --- a/services/backtester/src/backtester/metrics.py +++ b/services/backtester/src/backtester/metrics.py @@ -266,7 +266,7 @@ def compute_detailed_metrics( largest_win=largest_win, largest_loss=largest_loss, avg_holding_period=avg_holding, - trade_pairs=[p for p in pairs], + trade_pairs=list(pairs), risk_free_rate=risk_free_rate, recovery_factor=recovery_factor, max_consecutive_losses=max_consec_losses, diff --git a/services/backtester/src/backtester/simulator.py b/services/backtester/src/backtester/simulator.py index 64c88dd..6bce18b 100644 --- a/services/backtester/src/backtester/simulator.py +++ b/services/backtester/src/backtester/simulator.py @@ -1,9 +1,8 @@ """Simulated order executor for backtesting.""" from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal -from typing import Optional from shared.models import OrderSide, Signal @@ -16,7 +15,7 @@ class SimulatedTrade: quantity: Decimal balance_after: Decimal fee: Decimal = Decimal("0") - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + timestamp: datetime = field(default_factory=lambda: datetime.now(UTC)) @dataclass @@ -27,8 +26,8 @@ class OpenPosition: side: OrderSide # BUY = long, SELL = short entry_price: Decimal quantity: Decimal - stop_loss: Optional[Decimal] = None - take_profit: Optional[Decimal] = None + stop_loss: Decimal | None = None + take_profit: Decimal | None = None class OrderSimulator: @@ -70,7 +69,7 @@ class OrderSimulator: remaining: list[OpenPosition] = [] for pos in self.open_positions: triggered = False - exit_price: Optional[Decimal] = None + exit_price: Decimal | None = None if pos.side == OrderSide.BUY: # Long position if pos.stop_loss is not None and candle_low <= pos.stop_loss: @@ -125,12 +124,12 @@ class OrderSimulator: def execute( self, signal: Signal, - timestamp: Optional[datetime] = None, - stop_loss: Optional[Decimal] = None, - take_profit: Optional[Decimal] = None, + timestamp: datetime | None = None, + stop_loss: Decimal | None = None, + take_profit: Decimal | None = None, ) -> bool: """Execute a signal with slippage and fees. Returns True if accepted.""" - ts = timestamp or datetime.now(timezone.utc) + ts = timestamp or datetime.now(UTC) exec_price = self._apply_slippage(signal.price, signal.side) fee = self._calculate_fee(exec_price, signal.quantity) diff --git a/services/backtester/src/backtester/walk_forward.py b/services/backtester/src/backtester/walk_forward.py index c7b7fd8..720ad5e 100644 --- a/services/backtester/src/backtester/walk_forward.py +++ b/services/backtester/src/backtester/walk_forward.py @@ -1,11 +1,11 @@ """Walk-forward analysis for strategy parameter optimization.""" +from collections.abc import Callable from dataclasses import dataclass, field from decimal import Decimal -from typing import Callable -from shared.models import Candle from backtester.engine import BacktestEngine, BacktestResult, StrategyProtocol +from shared.models import Candle @dataclass diff --git a/services/backtester/tests/test_engine.py b/services/backtester/tests/test_engine.py index 4794e63..f789831 100644 --- a/services/backtester/tests/test_engine.py +++ b/services/backtester/tests/test_engine.py @@ -1,20 +1,19 @@ """Tests for the BacktestEngine.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal from unittest.mock import MagicMock - -from shared.models import Candle, Signal, OrderSide - from backtester.engine import BacktestEngine +from shared.models import Candle, OrderSide, Signal + def make_candle(symbol: str, price: float, timeframe: str = "1h") -> Candle: return Candle( symbol=symbol, timeframe=timeframe, - open_time=datetime.now(timezone.utc), + open_time=datetime.now(UTC), open=Decimal(str(price)), high=Decimal(str(price * 1.01)), low=Decimal(str(price * 0.99)), diff --git a/services/backtester/tests/test_metrics.py b/services/backtester/tests/test_metrics.py index 55f5b6c..13e545e 100644 --- a/services/backtester/tests/test_metrics.py +++ b/services/backtester/tests/test_metrics.py @@ -1,17 +1,16 @@ """Tests for detailed backtest metrics.""" import math -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from decimal import Decimal import pytest - from backtester.metrics import TradeRecord, compute_detailed_metrics def _make_trade(side: str, price: str, minutes_offset: int = 0) -> TradeRecord: return TradeRecord( - time=datetime(2025, 1, 1, tzinfo=timezone.utc) + timedelta(minutes=minutes_offset), + time=datetime(2025, 1, 1, tzinfo=UTC) + timedelta(minutes=minutes_offset), symbol="AAPL", side=side, price=Decimal(price), @@ -124,7 +123,7 @@ def test_consecutive_losses(): def test_risk_free_rate_affects_sharpe(): """Higher risk-free rate should lower Sharpe ratio.""" - base = datetime(2025, 1, 1, tzinfo=timezone.utc) + base = datetime(2025, 1, 1, tzinfo=UTC) trades = [ TradeRecord( time=base, symbol="AAPL", side="BUY", price=Decimal("100"), quantity=Decimal("1") @@ -184,7 +183,7 @@ def test_daily_returns_populated(): def test_fee_subtracted_from_pnl(): """Fees should be subtracted from trade PnL.""" - base = datetime(2025, 1, 1, tzinfo=timezone.utc) + base = datetime(2025, 1, 1, tzinfo=UTC) trades_with_fees = [ TradeRecord( time=base, diff --git a/services/backtester/tests/test_simulator.py b/services/backtester/tests/test_simulator.py index 62e2cdb..f85594f 100644 --- a/services/backtester/tests/test_simulator.py +++ b/services/backtester/tests/test_simulator.py @@ -1,11 +1,12 @@ """Tests for the OrderSimulator.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal -from shared.models import OrderSide, Signal from backtester.simulator import OrderSimulator +from shared.models import OrderSide, Signal + def make_signal( symbol: str, @@ -135,7 +136,7 @@ def test_stop_loss_triggers(): signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(signal, stop_loss=Decimal("48000")) - ts = datetime(2025, 1, 1, tzinfo=timezone.utc) + ts = datetime(2025, 1, 1, tzinfo=UTC) closed = sim.check_stops( candle_high=Decimal("50500"), candle_low=Decimal("47500"), # below stop_loss @@ -153,7 +154,7 @@ def test_take_profit_triggers(): signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(signal, take_profit=Decimal("55000")) - ts = datetime(2025, 1, 1, tzinfo=timezone.utc) + ts = datetime(2025, 1, 1, tzinfo=UTC) closed = sim.check_stops( candle_high=Decimal("56000"), # above take_profit candle_low=Decimal("50000"), @@ -171,7 +172,7 @@ def test_stop_not_triggered_within_range(): signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(signal, stop_loss=Decimal("48000"), take_profit=Decimal("55000")) - ts = datetime(2025, 1, 1, tzinfo=timezone.utc) + ts = datetime(2025, 1, 1, tzinfo=UTC) closed = sim.check_stops( candle_high=Decimal("52000"), candle_low=Decimal("49000"), @@ -212,7 +213,7 @@ def test_short_stop_loss(): signal = make_signal("AAPL", OrderSide.SELL, "50000", "0.1") sim.execute(signal, stop_loss=Decimal("52000")) - ts = datetime(2025, 1, 1, tzinfo=timezone.utc) + ts = datetime(2025, 1, 1, tzinfo=UTC) closed = sim.check_stops( candle_high=Decimal("53000"), # above stop_loss candle_low=Decimal("49000"), diff --git a/services/backtester/tests/test_walk_forward.py b/services/backtester/tests/test_walk_forward.py index 96abb6e..b1aa12c 100644 --- a/services/backtester/tests/test_walk_forward.py +++ b/services/backtester/tests/test_walk_forward.py @@ -1,18 +1,18 @@ """Tests for walk-forward analysis.""" import sys -from pathlib import Path +from datetime import UTC, datetime, timedelta from decimal import Decimal -from datetime import datetime, timedelta, timezone - +from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "strategy-engine")) -from shared.models import Candle from backtester.walk_forward import WalkForwardEngine, WalkForwardResult from strategies.rsi_strategy import RsiStrategy +from shared.models import Candle + def _generate_candles(n=100, base_price=100.0): candles = [] @@ -23,7 +23,7 @@ def _generate_candles(n=100, base_price=100.0): Candle( symbol="AAPL", timeframe="1h", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc) + timedelta(hours=i), + open_time=datetime(2025, 1, 1, tzinfo=UTC) + timedelta(hours=i), open=Decimal(str(price)), high=Decimal(str(price + 5)), low=Decimal(str(price - 5)), diff --git a/services/data-collector/Dockerfile b/services/data-collector/Dockerfile index 8cb8af4..4d154c5 100644 --- a/services/data-collector/Dockerfile +++ b/services/data-collector/Dockerfile @@ -1,8 +1,15 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/data-collector/ services/data-collector/ RUN pip install --no-cache-dir ./services/data-collector + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin ENV PYTHONPATH=/app +USER appuser CMD ["python", "-m", "data_collector.main"] diff --git a/services/data-collector/src/data_collector/main.py b/services/data-collector/src/data_collector/main.py index b42b34c..2d44848 100644 --- a/services/data-collector/src/data_collector/main.py +++ b/services/data-collector/src/data_collector/main.py @@ -2,6 +2,9 @@ import asyncio +import aiohttp + +from data_collector.config import CollectorConfig from shared.alpaca import AlpacaClient from shared.broker import RedisBroker from shared.db import Database @@ -11,8 +14,7 @@ from shared.logging import setup_logging from shared.metrics import ServiceMetrics from shared.models import Candle from shared.notifier import TelegramNotifier - -from data_collector.config import CollectorConfig +from shared.shutdown import GracefulShutdown # Health check port: base + 0 HEALTH_PORT_OFFSET = 0 @@ -45,8 +47,10 @@ async def fetch_latest_bars( volume=Decimal(str(bar["v"])), ) candles.append(candle) - except Exception as exc: - log.warning("fetch_bar_failed", symbol=symbol, error=str(exc)) + except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc: + log.warning("fetch_bar_network_error", symbol=symbol, error=str(exc)) + except (ValueError, KeyError, TypeError) as exc: + log.warning("fetch_bar_parse_error", symbol=symbol, error=str(exc)) return candles @@ -56,18 +60,18 @@ async def run() -> None: metrics = ServiceMetrics("data_collector") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id, ) - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) alpaca = AlpacaClient( - api_key=config.alpaca_api_key, - api_secret=config.alpaca_api_secret, + api_key=config.alpaca_api_key.get_secret_value(), + api_secret=config.alpaca_api_secret.get_secret_value(), paper=config.alpaca_paper, ) @@ -83,14 +87,17 @@ async def run() -> None: symbols = config.symbols timeframe = config.timeframes[0] if config.timeframes else "1Day" + shutdown = GracefulShutdown() + shutdown.install_handlers() + log.info("starting", symbols=symbols, timeframe=timeframe, poll_interval=poll_interval) try: - while True: + while not shutdown.is_shutting_down: # Check if market is open try: is_open = await alpaca.is_market_open() - except Exception: + except (aiohttp.ClientError, ConnectionError, TimeoutError): is_open = False if is_open: @@ -109,7 +116,7 @@ async def run() -> None: await asyncio.sleep(poll_interval) except Exception as exc: - log.error("fatal_error", error=str(exc)) + log.error("fatal_error", error=str(exc), exc_info=True) await notifier.send_error(str(exc), "data-collector") raise finally: diff --git a/services/data-collector/tests/test_storage.py b/services/data-collector/tests/test_storage.py index ffffa40..51f3aee 100644 --- a/services/data-collector/tests/test_storage.py +++ b/services/data-collector/tests/test_storage.py @@ -1,19 +1,20 @@ """Tests for storage module.""" -import pytest +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock -from shared.models import Candle +import pytest from data_collector.storage import CandleStorage +from shared.models import Candle + def _make_candle(symbol: str = "AAPL") -> Candle: return Candle( symbol=symbol, timeframe="1m", - open_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC), open=Decimal("30000"), high=Decimal("30100"), low=Decimal("29900"), diff --git a/services/news-collector/Dockerfile b/services/news-collector/Dockerfile index a8e5902..7accee2 100644 --- a/services/news-collector/Dockerfile +++ b/services/news-collector/Dockerfile @@ -1,9 +1,17 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/news-collector/ services/news-collector/ RUN pip install --no-cache-dir ./services/news-collector -RUN python -c "import nltk; nltk.download('vader_lexicon', quiet=True)" +RUN python -c "import nltk; nltk.download('vader_lexicon', download_dir='/usr/local/nltk_data')" + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin +COPY --from=builder /usr/local/nltk_data /usr/local/nltk_data ENV PYTHONPATH=/app +USER appuser CMD ["python", "-m", "news_collector.main"] diff --git a/services/news-collector/pyproject.toml b/services/news-collector/pyproject.toml index 14c856a..6e62b70 100644 --- a/services/news-collector/pyproject.toml +++ b/services/news-collector/pyproject.toml @@ -3,12 +3,7 @@ name = "news-collector" version = "0.1.0" description = "News and sentiment data collector service" requires-python = ">=3.12" -dependencies = [ - "trading-shared", - "feedparser>=6.0", - "nltk>=3.8", - "aiohttp>=3.9", -] +dependencies = ["trading-shared", "feedparser>=6.0,<7", "nltk>=3.8,<4", "aiohttp>=3.9,<4"] [project.optional-dependencies] dev = [ diff --git a/services/news-collector/src/news_collector/collectors/fear_greed.py b/services/news-collector/src/news_collector/collectors/fear_greed.py index f79f716..42e8f88 100644 --- a/services/news-collector/src/news_collector/collectors/fear_greed.py +++ b/services/news-collector/src/news_collector/collectors/fear_greed.py @@ -2,7 +2,6 @@ import logging from dataclasses import dataclass -from typing import Optional import aiohttp @@ -26,7 +25,7 @@ class FearGreedCollector(BaseCollector): async def is_available(self) -> bool: return True - async def _fetch_index(self) -> Optional[dict]: + async def _fetch_index(self) -> dict | None: headers = {"User-Agent": "Mozilla/5.0"} try: async with aiohttp.ClientSession() as session: @@ -50,7 +49,7 @@ class FearGreedCollector(BaseCollector): return "Greed" return "Extreme Greed" - async def collect(self) -> Optional[FearGreedResult]: + async def collect(self) -> FearGreedResult | None: data = await self._fetch_index() if data is None: return None diff --git a/services/news-collector/src/news_collector/collectors/fed.py b/services/news-collector/src/news_collector/collectors/fed.py index fce4842..52128e5 100644 --- a/services/news-collector/src/news_collector/collectors/fed.py +++ b/services/news-collector/src/news_collector/collectors/fed.py @@ -3,7 +3,7 @@ import asyncio import logging from calendar import timegm -from datetime import datetime, timezone +from datetime import UTC, datetime import feedparser from nltk.sentiment.vader import SentimentIntensityAnalyzer @@ -76,10 +76,10 @@ class FedCollector(BaseCollector): if published_parsed: try: ts = timegm(published_parsed) - return datetime.fromtimestamp(ts, tz=timezone.utc) + return datetime.fromtimestamp(ts, tz=UTC) except Exception: pass - return datetime.now(timezone.utc) + return datetime.now(UTC) async def collect(self) -> list[NewsItem]: try: diff --git a/services/news-collector/src/news_collector/collectors/finnhub.py b/services/news-collector/src/news_collector/collectors/finnhub.py index 13e3602..67cb455 100644 --- a/services/news-collector/src/news_collector/collectors/finnhub.py +++ b/services/news-collector/src/news_collector/collectors/finnhub.py @@ -1,7 +1,7 @@ """Finnhub news collector with VADER sentiment analysis.""" import logging -from datetime import datetime, timezone +from datetime import UTC, datetime import aiohttp from nltk.sentiment.vader import SentimentIntensityAnalyzer @@ -64,7 +64,7 @@ class FinnhubCollector(BaseCollector): sentiment = sentiment_scores["compound"] ts = article.get("datetime", 0) - published_at = datetime.fromtimestamp(ts, tz=timezone.utc) + published_at = datetime.fromtimestamp(ts, tz=UTC) related = article.get("related", "") symbols = [t.strip() for t in related.split(",") if t.strip()] if related else [] diff --git a/services/news-collector/src/news_collector/collectors/reddit.py b/services/news-collector/src/news_collector/collectors/reddit.py index 226a2f9..4e9d6f5 100644 --- a/services/news-collector/src/news_collector/collectors/reddit.py +++ b/services/news-collector/src/news_collector/collectors/reddit.py @@ -2,7 +2,7 @@ import logging import re -from datetime import datetime, timezone +from datetime import UTC, datetime import aiohttp from nltk.sentiment.vader import SentimentIntensityAnalyzer @@ -78,7 +78,7 @@ class RedditCollector(BaseCollector): symbols = list(dict.fromkeys(_TICKER_PATTERN.findall(combined))) created_utc = post_data.get("created_utc", 0) - published_at = datetime.fromtimestamp(created_utc, tz=timezone.utc) + published_at = datetime.fromtimestamp(created_utc, tz=UTC) items.append( NewsItem( diff --git a/services/news-collector/src/news_collector/collectors/rss.py b/services/news-collector/src/news_collector/collectors/rss.py index ddf8503..bca0e9f 100644 --- a/services/news-collector/src/news_collector/collectors/rss.py +++ b/services/news-collector/src/news_collector/collectors/rss.py @@ -3,7 +3,7 @@ import asyncio import logging import re -from datetime import datetime, timezone +from datetime import UTC, datetime from time import mktime import feedparser @@ -56,10 +56,10 @@ class RSSCollector(BaseCollector): if parsed_time: try: ts = mktime(parsed_time) - return datetime.fromtimestamp(ts, tz=timezone.utc) + return datetime.fromtimestamp(ts, tz=UTC) except Exception: pass - return datetime.now(timezone.utc) + return datetime.now(UTC) async def collect(self) -> list[NewsItem]: try: diff --git a/services/news-collector/src/news_collector/collectors/sec_edgar.py b/services/news-collector/src/news_collector/collectors/sec_edgar.py index ca1d070..d88518f 100644 --- a/services/news-collector/src/news_collector/collectors/sec_edgar.py +++ b/services/news-collector/src/news_collector/collectors/sec_edgar.py @@ -1,13 +1,13 @@ """SEC EDGAR filing collector (free, no API key required).""" import logging -from datetime import datetime, timezone +from datetime import UTC, datetime import aiohttp from nltk.sentiment.vader import SentimentIntensityAnalyzer -from shared.models import NewsCategory, NewsItem from news_collector.collectors.base import BaseCollector +from shared.models import NewsCategory, NewsItem logger = logging.getLogger(__name__) @@ -58,7 +58,7 @@ class SecEdgarCollector(BaseCollector): async def collect(self) -> list[NewsItem]: filings_data = await self._fetch_recent_filings() items = [] - today = datetime.now(timezone.utc).strftime("%Y-%m-%d") + today = datetime.now(UTC).strftime("%Y-%m-%d") for company_data in filings_data: tickers = [t["ticker"] for t in company_data.get("tickers", [])] @@ -87,9 +87,7 @@ class SecEdgarCollector(BaseCollector): headline=headline, summary=desc, url=f"https://www.sec.gov/cgi-bin/browse-edgar?action=getcompany&accession={accession}", - published_at=datetime.strptime(filing_date, "%Y-%m-%d").replace( - tzinfo=timezone.utc - ), + published_at=datetime.strptime(filing_date, "%Y-%m-%d").replace(tzinfo=UTC), symbols=tickers, sentiment=self._vader.polarity_scores(headline)["compound"], category=NewsCategory.FILING, diff --git a/services/news-collector/src/news_collector/collectors/truth_social.py b/services/news-collector/src/news_collector/collectors/truth_social.py index 33ebc86..e2acd88 100644 --- a/services/news-collector/src/news_collector/collectors/truth_social.py +++ b/services/news-collector/src/news_collector/collectors/truth_social.py @@ -2,7 +2,7 @@ import logging import re -from datetime import datetime, timezone +from datetime import UTC, datetime import aiohttp from nltk.sentiment.vader import SentimentIntensityAnalyzer @@ -67,7 +67,7 @@ class TruthSocialCollector(BaseCollector): try: published_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) except Exception: - published_at = datetime.now(timezone.utc) + published_at = datetime.now(UTC) items.append( NewsItem( diff --git a/services/news-collector/src/news_collector/config.py b/services/news-collector/src/news_collector/config.py index 70d98f1..6e78eba 100644 --- a/services/news-collector/src/news_collector/config.py +++ b/services/news-collector/src/news_collector/config.py @@ -5,6 +5,3 @@ from shared.config import Settings class NewsCollectorConfig(Settings): health_port: int = 8084 - finnhub_api_key: str = "" - news_poll_interval: int = 300 - sentiment_aggregate_interval: int = 900 diff --git a/services/news-collector/src/news_collector/main.py b/services/news-collector/src/news_collector/main.py index 3493f7c..c39fa67 100644 --- a/services/news-collector/src/news_collector/main.py +++ b/services/news-collector/src/news_collector/main.py @@ -1,8 +1,18 @@ """News Collector Service — fetches news from multiple sources and aggregates sentiment.""" import asyncio -from datetime import datetime, timezone +from datetime import UTC, datetime +import aiohttp + +from news_collector.collectors.fear_greed import FearGreedCollector +from news_collector.collectors.fed import FedCollector +from news_collector.collectors.finnhub import FinnhubCollector +from news_collector.collectors.reddit import RedditCollector +from news_collector.collectors.rss import RSSCollector +from news_collector.collectors.sec_edgar import SecEdgarCollector +from news_collector.collectors.truth_social import TruthSocialCollector +from news_collector.config import NewsCollectorConfig from shared.broker import RedisBroker from shared.db import Database from shared.events import NewsEvent @@ -11,20 +21,9 @@ from shared.logging import setup_logging from shared.metrics import ServiceMetrics from shared.models import NewsItem from shared.notifier import TelegramNotifier -from shared.sentiment_models import MarketSentiment from shared.sentiment import SentimentAggregator - -from news_collector.config import NewsCollectorConfig -from news_collector.collectors.finnhub import FinnhubCollector -from news_collector.collectors.rss import RSSCollector -from news_collector.collectors.sec_edgar import SecEdgarCollector -from news_collector.collectors.truth_social import TruthSocialCollector -from news_collector.collectors.reddit import RedditCollector -from news_collector.collectors.fear_greed import FearGreedCollector -from news_collector.collectors.fed import FedCollector - -# Health check port: base + 4 -HEALTH_PORT_OFFSET = 4 +from shared.sentiment_models import MarketSentiment +from shared.shutdown import GracefulShutdown async def run_collector_once(collector, db: Database, broker: RedisBroker) -> int: @@ -53,9 +52,15 @@ async def run_collector_loop(collector, db: Database, broker: RedisBroker, log) collector=collector.name, count=count, ) - except Exception as exc: + except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc: log.warning( - "collector_error", + "collector_network_error", + collector=collector.name, + error=str(exc), + ) + except (ValueError, KeyError, TypeError) as exc: + log.warning( + "collector_parse_error", collector=collector.name, error=str(exc), ) @@ -74,7 +79,7 @@ async def run_fear_greed_loop(collector: FearGreedCollector, db: Database, log) vix=None, fed_stance="neutral", market_regime=_determine_regime(result.fear_greed, None), - updated_at=datetime.now(timezone.utc), + updated_at=datetime.now(UTC), ) await db.upsert_market_sentiment(ms) log.info( @@ -82,8 +87,10 @@ async def run_fear_greed_loop(collector: FearGreedCollector, db: Database, log) value=result.fear_greed, label=result.fear_greed_label, ) - except Exception as exc: - log.warning("fear_greed_error", error=str(exc)) + except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc: + log.warning("fear_greed_network_error", error=str(exc)) + except (ValueError, KeyError, TypeError) as exc: + log.warning("fear_greed_parse_error", error=str(exc)) await asyncio.sleep(collector.poll_interval) @@ -93,14 +100,16 @@ async def run_aggregator_loop(db: Database, interval: int, log) -> None: while True: await asyncio.sleep(interval) try: - now = datetime.now(timezone.utc) + now = datetime.now(UTC) news_items = await db.get_recent_news(hours=24) scores = aggregator.aggregate(news_items, now) for score in scores.values(): await db.upsert_symbol_score(score) log.info("aggregation_complete", symbols=len(scores)) - except Exception as exc: - log.warning("aggregator_error", error=str(exc)) + except (ConnectionError, TimeoutError) as exc: + log.warning("aggregator_network_error", error=str(exc)) + except (ValueError, KeyError, TypeError) as exc: + log.warning("aggregator_parse_error", error=str(exc)) def _determine_regime(fear_greed: int, vix: float | None) -> str: @@ -115,14 +124,14 @@ async def run() -> None: metrics = ServiceMetrics("news_collector") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id, ) - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) health = HealthCheckServer( "news-collector", @@ -133,7 +142,7 @@ async def run() -> None: metrics.service_up.labels(service="news-collector").set(1) # Build collectors - finnhub = FinnhubCollector(api_key=config.finnhub_api_key) + finnhub = FinnhubCollector(api_key=config.finnhub_api_key.get_secret_value()) rss = RSSCollector() sec = SecEdgarCollector() truth = TruthSocialCollector() @@ -143,6 +152,9 @@ async def run() -> None: news_collectors = [finnhub, rss, sec, truth, reddit, fed] + shutdown = GracefulShutdown() + shutdown.install_handlers() + log.info( "starting", collectors=[c.name for c in news_collectors], @@ -151,14 +163,13 @@ async def run() -> None: ) try: - tasks = [] - for collector in news_collectors: - tasks.append( - asyncio.create_task( - run_collector_loop(collector, db, broker, log), - name=f"collector-{collector.name}", - ) + tasks = [ + asyncio.create_task( + run_collector_loop(collector, db, broker, log), + name=f"collector-{collector.name}", ) + for collector in news_collectors + ] tasks.append( asyncio.create_task( run_fear_greed_loop(fear_greed, db, log), @@ -171,9 +182,9 @@ async def run() -> None: name="aggregator-loop", ) ) - await asyncio.gather(*tasks) + await shutdown.wait() except Exception as exc: - log.error("fatal_error", error=str(exc)) + log.error("fatal_error", error=str(exc), exc_info=True) await notifier.send_error(str(exc), "news-collector") raise finally: diff --git a/services/news-collector/tests/test_fear_greed.py b/services/news-collector/tests/test_fear_greed.py index d483aa6..e8bd8f0 100644 --- a/services/news-collector/tests/test_fear_greed.py +++ b/services/news-collector/tests/test_fear_greed.py @@ -1,8 +1,8 @@ """Tests for CNN Fear & Greed Index collector.""" -import pytest from unittest.mock import AsyncMock, patch +import pytest from news_collector.collectors.fear_greed import FearGreedCollector diff --git a/services/news-collector/tests/test_fed.py b/services/news-collector/tests/test_fed.py index d1a736b..7f1c46c 100644 --- a/services/news-collector/tests/test_fed.py +++ b/services/news-collector/tests/test_fed.py @@ -1,7 +1,8 @@ """Tests for Federal Reserve collector.""" -import pytest from unittest.mock import AsyncMock, patch + +import pytest from news_collector.collectors.fed import FedCollector diff --git a/services/news-collector/tests/test_finnhub.py b/services/news-collector/tests/test_finnhub.py index a4cf169..3af65b8 100644 --- a/services/news-collector/tests/test_finnhub.py +++ b/services/news-collector/tests/test_finnhub.py @@ -1,8 +1,8 @@ """Tests for Finnhub news collector.""" -import pytest from unittest.mock import AsyncMock, patch +import pytest from news_collector.collectors.finnhub import FinnhubCollector diff --git a/services/news-collector/tests/test_main.py b/services/news-collector/tests/test_main.py index 66190dc..f85569a 100644 --- a/services/news-collector/tests/test_main.py +++ b/services/news-collector/tests/test_main.py @@ -1,16 +1,18 @@ """Tests for news collector scheduler.""" +from datetime import UTC, datetime from unittest.mock import AsyncMock, MagicMock -from datetime import datetime, timezone -from shared.models import NewsCategory, NewsItem + from news_collector.main import run_collector_once +from shared.models import NewsCategory, NewsItem + async def test_run_collector_once_stores_and_publishes(): mock_item = NewsItem( source="test", headline="Test news", - published_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + published_at=datetime(2026, 4, 2, tzinfo=UTC), sentiment=0.5, category=NewsCategory.MACRO, ) diff --git a/services/news-collector/tests/test_reddit.py b/services/news-collector/tests/test_reddit.py index 440b173..31b1dc1 100644 --- a/services/news-collector/tests/test_reddit.py +++ b/services/news-collector/tests/test_reddit.py @@ -1,7 +1,8 @@ """Tests for Reddit collector.""" -import pytest from unittest.mock import AsyncMock, patch + +import pytest from news_collector.collectors.reddit import RedditCollector diff --git a/services/news-collector/tests/test_rss.py b/services/news-collector/tests/test_rss.py index e03250a..7242c75 100644 --- a/services/news-collector/tests/test_rss.py +++ b/services/news-collector/tests/test_rss.py @@ -1,8 +1,8 @@ """Tests for RSS news collector.""" -import pytest from unittest.mock import AsyncMock, patch +import pytest from news_collector.collectors.rss import RSSCollector diff --git a/services/news-collector/tests/test_sec_edgar.py b/services/news-collector/tests/test_sec_edgar.py index 5d4f69f..b0faf18 100644 --- a/services/news-collector/tests/test_sec_edgar.py +++ b/services/news-collector/tests/test_sec_edgar.py @@ -1,9 +1,9 @@ """Tests for SEC EDGAR filing collector.""" -import pytest -from datetime import datetime, timezone -from unittest.mock import AsyncMock, patch, MagicMock +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock, patch +import pytest from news_collector.collectors.sec_edgar import SecEdgarCollector @@ -37,7 +37,7 @@ async def test_collect_parses_filings(collector): } mock_datetime = MagicMock(spec=datetime) - mock_datetime.now.return_value = datetime(2026, 4, 2, tzinfo=timezone.utc) + mock_datetime.now.return_value = datetime(2026, 4, 2, tzinfo=UTC) mock_datetime.strptime = datetime.strptime with patch.object( diff --git a/services/news-collector/tests/test_truth_social.py b/services/news-collector/tests/test_truth_social.py index 91ddb9d..52f1e46 100644 --- a/services/news-collector/tests/test_truth_social.py +++ b/services/news-collector/tests/test_truth_social.py @@ -1,7 +1,8 @@ """Tests for Truth Social collector.""" -import pytest from unittest.mock import AsyncMock, patch + +import pytest from news_collector.collectors.truth_social import TruthSocialCollector diff --git a/services/order-executor/Dockerfile b/services/order-executor/Dockerfile index bc8b21c..376afec 100644 --- a/services/order-executor/Dockerfile +++ b/services/order-executor/Dockerfile @@ -1,8 +1,15 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/order-executor/ services/order-executor/ RUN pip install --no-cache-dir ./services/order-executor + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin ENV PYTHONPATH=/app +USER appuser CMD ["python", "-m", "order_executor.main"] diff --git a/services/order-executor/src/order_executor/executor.py b/services/order-executor/src/order_executor/executor.py index a71e762..fd502cd 100644 --- a/services/order-executor/src/order_executor/executor.py +++ b/services/order-executor/src/order_executor/executor.py @@ -1,18 +1,18 @@ """Order execution logic.""" -import structlog -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal -from typing import Any, Optional +from typing import Any + +import structlog +from order_executor.risk_manager import RiskManager from shared.broker import RedisBroker from shared.db import Database from shared.events import OrderEvent from shared.models import Order, OrderStatus, OrderType, Signal from shared.notifier import TelegramNotifier -from order_executor.risk_manager import RiskManager - logger = structlog.get_logger() @@ -35,7 +35,7 @@ class OrderExecutor: self.notifier = notifier self.dry_run = dry_run - async def execute(self, signal: Signal) -> Optional[Order]: + async def execute(self, signal: Signal) -> Order | None: """Run risk checks and place an order for the given signal.""" # Fetch buying power from Alpaca balance = await self.exchange.get_buying_power() @@ -71,7 +71,7 @@ class OrderExecutor: if self.dry_run: order.status = OrderStatus.FILLED - order.filled_at = datetime.now(timezone.utc) + order.filled_at = datetime.now(UTC) logger.info( "order_filled_dry_run", side=str(order.side), @@ -87,7 +87,7 @@ class OrderExecutor: type="market", ) order.status = OrderStatus.FILLED - order.filled_at = datetime.now(timezone.utc) + order.filled_at = datetime.now(UTC) logger.info( "order_filled", side=str(order.side), diff --git a/services/order-executor/src/order_executor/main.py b/services/order-executor/src/order_executor/main.py index 51ab286..99f88e1 100644 --- a/services/order-executor/src/order_executor/main.py +++ b/services/order-executor/src/order_executor/main.py @@ -3,6 +3,11 @@ import asyncio from decimal import Decimal +import aiohttp + +from order_executor.config import ExecutorConfig +from order_executor.executor import OrderExecutor +from order_executor.risk_manager import RiskManager from shared.alpaca import AlpacaClient from shared.broker import RedisBroker from shared.db import Database @@ -11,10 +16,7 @@ from shared.healthcheck import HealthCheckServer from shared.logging import setup_logging from shared.metrics import ServiceMetrics from shared.notifier import TelegramNotifier - -from order_executor.config import ExecutorConfig -from order_executor.executor import OrderExecutor -from order_executor.risk_manager import RiskManager +from shared.shutdown import GracefulShutdown # Health check port: base + 2 HEALTH_PORT_OFFSET = 2 @@ -26,18 +28,18 @@ async def run() -> None: metrics = ServiceMetrics("order_executor") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id, ) - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) alpaca = AlpacaClient( - api_key=config.alpaca_api_key, - api_secret=config.alpaca_api_secret, + api_key=config.alpaca_api_key.get_secret_value(), + api_secret=config.alpaca_api_secret.get_secret_value(), paper=config.alpaca_paper, ) @@ -83,6 +85,9 @@ async def run() -> None: await broker.ensure_group(stream, GROUP) + shutdown = GracefulShutdown() + shutdown.install_handlers() + log.info("started", stream=stream, dry_run=config.dry_run) try: @@ -94,10 +99,15 @@ async def run() -> None: if event.type == EventType.SIGNAL: await executor.execute(event.data) await broker.ack(stream, GROUP, msg_id) + except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc: + log.warning("pending_network_error", error=str(exc), msg_id=msg_id) + except (ValueError, KeyError, TypeError) as exc: + log.warning("pending_parse_error", error=str(exc), msg_id=msg_id) + await broker.ack(stream, GROUP, msg_id) except Exception as exc: - log.error("pending_failed", error=str(exc), msg_id=msg_id) + log.error("pending_failed", error=str(exc), msg_id=msg_id, exc_info=True) - while True: + while not shutdown.is_shutting_down: messages = await broker.read_group(stream, GROUP, CONSUMER, count=10, block=5000) for msg_id, msg in messages: try: @@ -110,8 +120,19 @@ async def run() -> None: service="order-executor", event_type="signal" ).inc() await broker.ack(stream, GROUP, msg_id) + except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc: + log.warning("process_network_error", error=str(exc)) + metrics.errors_total.labels( + service="order-executor", error_type="network" + ).inc() + except (ValueError, KeyError, TypeError) as exc: + log.warning("process_parse_error", error=str(exc)) + await broker.ack(stream, GROUP, msg_id) + metrics.errors_total.labels( + service="order-executor", error_type="validation" + ).inc() except Exception as exc: - log.error("process_failed", error=str(exc)) + log.error("process_failed", error=str(exc), exc_info=True) metrics.errors_total.labels( service="order-executor", error_type="processing" ).inc() diff --git a/services/order-executor/src/order_executor/risk_manager.py b/services/order-executor/src/order_executor/risk_manager.py index 5a05746..811a862 100644 --- a/services/order-executor/src/order_executor/risk_manager.py +++ b/services/order-executor/src/order_executor/risk_manager.py @@ -1,12 +1,12 @@ """Risk management for order execution.""" +import math +from collections import deque from dataclasses import dataclass -from datetime import datetime, timezone, timedelta +from datetime import UTC, datetime, timedelta from decimal import Decimal -from collections import deque -import math -from shared.models import Signal, OrderSide, Position +from shared.models import OrderSide, Position, Signal @dataclass @@ -123,15 +123,13 @@ class RiskManager: else: self._consecutive_losses += 1 if self._consecutive_losses >= self._max_consecutive_losses: - self._paused_until = datetime.now(timezone.utc) + timedelta( - minutes=self._loss_pause_minutes - ) + self._paused_until = datetime.now(UTC) + timedelta(minutes=self._loss_pause_minutes) def is_paused(self) -> bool: """Check if trading is paused due to consecutive losses.""" if self._paused_until is None: return False - if datetime.now(timezone.utc) >= self._paused_until: + if datetime.now(UTC) >= self._paused_until: self._paused_until = None self._consecutive_losses = 0 return False @@ -233,9 +231,9 @@ class RiskManager: mean_a = sum(returns_a) / len(returns_a) mean_b = sum(returns_b) / len(returns_b) - cov = sum((a - mean_a) * (b - mean_b) for a, b in zip(returns_a, returns_b)) / len( - returns_a - ) + cov = sum( + (a - mean_a) * (b - mean_b) for a, b in zip(returns_a, returns_b, strict=True) + ) / len(returns_a) std_a = math.sqrt(sum((a - mean_a) ** 2 for a in returns_a) / len(returns_a)) std_b = math.sqrt(sum((b - mean_b) ** 2 for b in returns_b) / len(returns_b)) @@ -280,7 +278,11 @@ class RiskManager: min_len = min(len(r) for r in all_returns) portfolio_returns = [] for i in range(min_len): - pr = sum(w * r[-(min_len - i)] for w, r in zip(weights, all_returns) if len(r) > i) + pr = sum( + w * r[-(min_len - i)] + for w, r in zip(weights, all_returns, strict=False) + if len(r) > i + ) portfolio_returns.append(pr) if not portfolio_returns: diff --git a/services/order-executor/tests/test_executor.py b/services/order-executor/tests/test_executor.py index dd823d7..cda6b72 100644 --- a/services/order-executor/tests/test_executor.py +++ b/services/order-executor/tests/test_executor.py @@ -4,11 +4,11 @@ from decimal import Decimal from unittest.mock import AsyncMock, MagicMock import pytest - -from shared.models import OrderSide, OrderStatus, Signal from order_executor.executor import OrderExecutor from order_executor.risk_manager import RiskCheckResult, RiskManager +from shared.models import OrderSide, OrderStatus, Signal + def make_signal(side: OrderSide = OrderSide.BUY, price: str = "100", quantity: str = "1") -> Signal: return Signal( diff --git a/services/order-executor/tests/test_risk_manager.py b/services/order-executor/tests/test_risk_manager.py index 3d5175b..66e769c 100644 --- a/services/order-executor/tests/test_risk_manager.py +++ b/services/order-executor/tests/test_risk_manager.py @@ -2,9 +2,9 @@ from decimal import Decimal +from order_executor.risk_manager import RiskManager from shared.models import OrderSide, Position, Signal -from order_executor.risk_manager import RiskManager def make_signal(side: OrderSide, price: str, quantity: str, symbol: str = "AAPL") -> Signal: diff --git a/services/portfolio-manager/Dockerfile b/services/portfolio-manager/Dockerfile index b1a7681..0fa3f35 100644 --- a/services/portfolio-manager/Dockerfile +++ b/services/portfolio-manager/Dockerfile @@ -1,8 +1,15 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/portfolio-manager/ services/portfolio-manager/ RUN pip install --no-cache-dir ./services/portfolio-manager + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin ENV PYTHONPATH=/app +USER appuser CMD ["python", "-m", "portfolio_manager.main"] diff --git a/services/portfolio-manager/src/portfolio_manager/main.py b/services/portfolio-manager/src/portfolio_manager/main.py index a6823ae..f885aa8 100644 --- a/services/portfolio-manager/src/portfolio_manager/main.py +++ b/services/portfolio-manager/src/portfolio_manager/main.py @@ -2,6 +2,10 @@ import asyncio +import sqlalchemy.exc + +from portfolio_manager.config import PortfolioConfig +from portfolio_manager.portfolio import PortfolioTracker from shared.broker import RedisBroker from shared.db import Database from shared.events import Event, OrderEvent @@ -9,9 +13,7 @@ from shared.healthcheck import HealthCheckServer from shared.logging import setup_logging from shared.metrics import ServiceMetrics from shared.notifier import TelegramNotifier - -from portfolio_manager.config import PortfolioConfig -from portfolio_manager.portfolio import PortfolioTracker +from shared.shutdown import GracefulShutdown ORDERS_STREAM = "orders" @@ -51,8 +53,12 @@ async def snapshot_loop( while True: try: await save_snapshot(db, tracker, notifier, log) + except (sqlalchemy.exc.OperationalError, ConnectionError, TimeoutError) as exc: + log.warning("snapshot_db_error", error=str(exc)) + except (ValueError, KeyError, TypeError) as exc: + log.warning("snapshot_data_error", error=str(exc)) except Exception as exc: - log.error("snapshot_failed", error=str(exc)) + log.error("snapshot_failed", error=str(exc), exc_info=True) await asyncio.sleep(interval_hours * 3600) @@ -61,10 +67,10 @@ async def run() -> None: log = setup_logging("portfolio-manager", config.log_level, config.log_format) metrics = ServiceMetrics("portfolio_manager") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id ) - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) tracker = PortfolioTracker() health = HealthCheckServer( @@ -76,13 +82,16 @@ async def run() -> None: await health.start() metrics.service_up.labels(service="portfolio-manager").set(1) - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() snapshot_task = asyncio.create_task( snapshot_loop(db, tracker, notifier, config.snapshot_interval_hours, log) ) + shutdown = GracefulShutdown() + shutdown.install_handlers() + GROUP = "portfolio-manager" CONSUMER = "portfolio-1" log.info("service_started", stream=ORDERS_STREAM) @@ -108,12 +117,16 @@ async def run() -> None: service="portfolio-manager", event_type="order" ).inc() await broker.ack(ORDERS_STREAM, GROUP, msg_id) + except (ValueError, KeyError, TypeError) as exc: + log.warning("pending_parse_error", error=str(exc), msg_id=msg_id) + await broker.ack(ORDERS_STREAM, GROUP, msg_id) + metrics.errors_total.labels(service="portfolio-manager", error_type="validation").inc() except Exception as exc: - log.error("pending_process_failed", error=str(exc), msg_id=msg_id) + log.error("pending_process_failed", error=str(exc), msg_id=msg_id, exc_info=True) metrics.errors_total.labels(service="portfolio-manager", error_type="processing").inc() try: - while True: + while not shutdown.is_shutting_down: messages = await broker.read_group(ORDERS_STREAM, GROUP, CONSUMER, count=10, block=1000) for msg_id, msg in messages: try: @@ -134,13 +147,21 @@ async def run() -> None: service="portfolio-manager", event_type="order" ).inc() await broker.ack(ORDERS_STREAM, GROUP, msg_id) + except (ValueError, KeyError, TypeError) as exc: + log.warning("message_parse_error", error=str(exc), msg_id=msg_id) + await broker.ack(ORDERS_STREAM, GROUP, msg_id) + metrics.errors_total.labels( + service="portfolio-manager", error_type="validation" + ).inc() except Exception as exc: - log.exception("message_processing_failed", error=str(exc), msg_id=msg_id) + log.error( + "message_processing_failed", error=str(exc), msg_id=msg_id, exc_info=True + ) metrics.errors_total.labels( service="portfolio-manager", error_type="processing" ).inc() except Exception as exc: - log.error("fatal_error", error=str(exc)) + log.error("fatal_error", error=str(exc), exc_info=True) await notifier.send_error(str(exc), "portfolio-manager") raise finally: diff --git a/services/portfolio-manager/tests/test_portfolio.py b/services/portfolio-manager/tests/test_portfolio.py index 365dc1a..c8a6894 100644 --- a/services/portfolio-manager/tests/test_portfolio.py +++ b/services/portfolio-manager/tests/test_portfolio.py @@ -2,9 +2,10 @@ from decimal import Decimal -from shared.models import Order, OrderSide, OrderStatus, OrderType from portfolio_manager.portfolio import PortfolioTracker +from shared.models import Order, OrderSide, OrderStatus, OrderType + def make_order(side: OrderSide, price: str, quantity: str) -> Order: """Helper to create a filled Order.""" diff --git a/services/portfolio-manager/tests/test_snapshot.py b/services/portfolio-manager/tests/test_snapshot.py index ec5e92d..f2026e2 100644 --- a/services/portfolio-manager/tests/test_snapshot.py +++ b/services/portfolio-manager/tests/test_snapshot.py @@ -1,9 +1,10 @@ """Tests for save_snapshot in portfolio-manager.""" -import pytest from decimal import Decimal from unittest.mock import AsyncMock, MagicMock +import pytest + from shared.models import Position diff --git a/services/strategy-engine/Dockerfile b/services/strategy-engine/Dockerfile index de635dc..f1484e9 100644 --- a/services/strategy-engine/Dockerfile +++ b/services/strategy-engine/Dockerfile @@ -1,9 +1,16 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/strategy-engine/ services/strategy-engine/ RUN pip install --no-cache-dir ./services/strategy-engine + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin COPY services/strategy-engine/strategies/ /app/strategies/ ENV PYTHONPATH=/app +USER appuser CMD ["python", "-m", "strategy_engine.main"] diff --git a/services/strategy-engine/pyproject.toml b/services/strategy-engine/pyproject.toml index 4f5b6be..e4bfb12 100644 --- a/services/strategy-engine/pyproject.toml +++ b/services/strategy-engine/pyproject.toml @@ -3,11 +3,7 @@ name = "strategy-engine" version = "0.1.0" description = "Plugin-based strategy execution engine" requires-python = ">=3.12" -dependencies = [ - "pandas>=2.0", - "numpy>=1.20", - "trading-shared", -] +dependencies = ["pandas>=2.1,<3", "numpy>=1.26,<3", "trading-shared"] [project.optional-dependencies] dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] diff --git a/services/strategy-engine/src/strategy_engine/config.py b/services/strategy-engine/src/strategy_engine/config.py index 2a9cb43..9fd9c49 100644 --- a/services/strategy-engine/src/strategy_engine/config.py +++ b/services/strategy-engine/src/strategy_engine/config.py @@ -7,9 +7,3 @@ class StrategyConfig(Settings): symbols: list[str] = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"] timeframes: list[str] = ["1m"] strategy_params: dict = {} - selector_candidates_time: str = "15:00" - selector_filter_time: str = "15:15" - selector_final_time: str = "15:30" - selector_max_picks: int = 3 - anthropic_api_key: str = "" - anthropic_model: str = "claude-sonnet-4-20250514" diff --git a/services/strategy-engine/src/strategy_engine/engine.py b/services/strategy-engine/src/strategy_engine/engine.py index d401aee..4b2c468 100644 --- a/services/strategy-engine/src/strategy_engine/engine.py +++ b/services/strategy-engine/src/strategy_engine/engine.py @@ -2,11 +2,11 @@ import logging -from shared.broker import RedisBroker -from shared.events import CandleEvent, SignalEvent, Event - from strategies.base import BaseStrategy +from shared.broker import RedisBroker +from shared.events import CandleEvent, Event, SignalEvent + logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ class StrategyEngine: try: event = Event.from_dict(raw) except Exception as exc: - logger.warning("Failed to parse event: %s – %s", raw, exc) + logger.warning("Failed to parse event: %s - %s", raw, exc) continue if not isinstance(event, CandleEvent): diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py index 5a30766..3d73058 100644 --- a/services/strategy-engine/src/strategy_engine/main.py +++ b/services/strategy-engine/src/strategy_engine/main.py @@ -1,9 +1,11 @@ """Strategy Engine Service entry point.""" import asyncio +import zoneinfo from datetime import datetime from pathlib import Path -import zoneinfo + +import aiohttp from shared.alpaca import AlpacaClient from shared.broker import RedisBroker @@ -13,7 +15,7 @@ from shared.logging import setup_logging from shared.metrics import ServiceMetrics from shared.notifier import TelegramNotifier from shared.sentiment_models import MarketSentiment - +from shared.shutdown import GracefulShutdown from strategy_engine.config import StrategyConfig from strategy_engine.engine import StrategyEngine from strategy_engine.plugin_loader import load_strategies @@ -63,8 +65,12 @@ async def run_stock_selector( log.info("stock_selector_complete", picks=[s.symbol for s in selections]) else: log.info("stock_selector_no_picks") + except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc: + log.warning("stock_selector_network_error", error=str(exc)) + except (ValueError, KeyError, TypeError) as exc: + log.warning("stock_selector_data_error", error=str(exc)) except Exception as exc: - log.error("stock_selector_error", error=str(exc)) + log.error("stock_selector_error", error=str(exc), exc_info=True) await asyncio.sleep(120) # Sleep past this minute else: await asyncio.sleep(30) @@ -76,18 +82,18 @@ async def run() -> None: metrics = ServiceMetrics("strategy_engine") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id, ) - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() alpaca = AlpacaClient( - api_key=config.alpaca_api_key, - api_secret=config.alpaca_api_secret, + api_key=config.alpaca_api_key.get_secret_value(), + api_secret=config.alpaca_api_secret.get_secret_value(), paper=config.alpaca_paper, ) @@ -97,6 +103,9 @@ async def run() -> None: params = config.strategy_params.get(strategy.name, {}) strategy.configure(params) + shutdown = GracefulShutdown() + shutdown.install_handlers() + log.info("loaded_strategies", count=len(strategies), names=[s.name for s in strategies]) engine = StrategyEngine(broker=broker, strategies=strategies) @@ -117,12 +126,12 @@ async def run() -> None: task = asyncio.create_task(process_symbol(engine, stream, log)) tasks.append(task) - if config.anthropic_api_key: + if config.anthropic_api_key.get_secret_value(): selector = StockSelector( db=db, broker=broker, alpaca=alpaca, - anthropic_api_key=config.anthropic_api_key, + anthropic_api_key=config.anthropic_api_key.get_secret_value(), anthropic_model=config.anthropic_model, max_picks=config.selector_max_picks, ) @@ -131,9 +140,9 @@ async def run() -> None: ) log.info("stock_selector_enabled", time=config.selector_final_time) - await asyncio.gather(*tasks) + await shutdown.wait() except Exception as exc: - log.error("fatal_error", error=str(exc)) + log.error("fatal_error", error=str(exc), exc_info=True) await notifier.send_error(str(exc), "strategy-engine") raise finally: diff --git a/services/strategy-engine/src/strategy_engine/plugin_loader.py b/services/strategy-engine/src/strategy_engine/plugin_loader.py index 62e4160..57680db 100644 --- a/services/strategy-engine/src/strategy_engine/plugin_loader.py +++ b/services/strategy-engine/src/strategy_engine/plugin_loader.py @@ -5,7 +5,6 @@ import sys from pathlib import Path import yaml - from strategies.base import BaseStrategy diff --git a/services/strategy-engine/src/strategy_engine/stock_selector.py b/services/strategy-engine/src/strategy_engine/stock_selector.py index 268d557..8657b93 100644 --- a/services/strategy-engine/src/strategy_engine/stock_selector.py +++ b/services/strategy-engine/src/strategy_engine/stock_selector.py @@ -1,9 +1,10 @@ """3-stage stock selector engine: sentiment → technical → LLM.""" +import asyncio import json import logging import re -from datetime import datetime, timezone +from datetime import UTC, datetime import aiohttp @@ -18,18 +19,12 @@ logger = logging.getLogger(__name__) ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages" -def _parse_llm_selections(text: str) -> list[SelectedStock]: - """Parse LLM response into SelectedStock list. - - Handles both bare JSON arrays and markdown code blocks. - Returns empty list on any parse error. - """ - # Try to extract JSON from markdown code block first +def _extract_json_array(text: str) -> list[dict] | None: + """Extract a JSON array from text that may contain markdown code blocks.""" code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) if code_block: raw = code_block.group(1) else: - # Try to find a bare JSON array array_match = re.search(r"\[.*\]", text, re.DOTALL) if array_match: raw = array_match.group(0) @@ -38,27 +33,38 @@ def _parse_llm_selections(text: str) -> list[SelectedStock]: try: data = json.loads(raw) - if not isinstance(data, list): - return [] - selections = [] - for item in data: - if not isinstance(item, dict): - continue - try: - selection = SelectedStock( - symbol=item["symbol"], - side=OrderSide(item["side"]), - conviction=float(item["conviction"]), - reason=item.get("reason", ""), - key_news=item.get("key_news", []), - ) - selections.append(selection) - except (KeyError, ValueError) as e: - logger.warning("Skipping invalid selection item: %s", e) - return selections + if isinstance(data, list): + return [item for item in data if isinstance(item, dict)] + return None except (json.JSONDecodeError, TypeError): + return None + + +def _parse_llm_selections(text: str) -> list[SelectedStock]: + """Parse LLM response into SelectedStock list. + + Handles both bare JSON arrays and markdown code blocks. + Returns empty list on any parse error. + """ + items = _extract_json_array(text) + if items is None: return [] + selections = [] + for item in items: + try: + selection = SelectedStock( + symbol=item["symbol"], + side=OrderSide(item["side"]), + conviction=float(item["conviction"]), + reason=item.get("reason", ""), + key_news=item.get("key_news", []), + ) + selections.append(selection) + except (KeyError, ValueError) as e: + logger.warning("Skipping invalid selection item: %s", e) + return selections + class SentimentCandidateSource: """Generates candidates from DB sentiment scores.""" @@ -92,7 +98,7 @@ class LLMCandidateSource: self._api_key = api_key self._model = model - async def get_candidates(self) -> list[Candidate]: + async def get_candidates(self, session: aiohttp.ClientSession | None = None) -> list[Candidate]: news_items = await self._db.get_recent_news(hours=24) if not news_items: return [] @@ -110,26 +116,29 @@ class LLMCandidateSource: "Headlines:\n" + "\n".join(headlines) ) + own_session = session is None + if own_session: + session = aiohttp.ClientSession() + try: - async with aiohttp.ClientSession() as session: - async with session.post( - ANTHROPIC_API_URL, - headers={ - "x-api-key": self._api_key, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - json={ - "model": self._model, - "max_tokens": 1024, - "messages": [{"role": "user", "content": prompt}], - }, - ) as resp: - if resp.status != 200: - body = await resp.text() - logger.error("LLM candidate source error %d: %s", resp.status, body) - return [] - data = await resp.json() + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM candidate source error %d: %s", resp.status, body) + return [] + data = await resp.json() content = data.get("content", []) text = "" @@ -141,40 +150,32 @@ class LLMCandidateSource: except Exception as e: logger.error("LLMCandidateSource error: %s", e) return [] + finally: + if own_session: + await session.close() def _parse_candidates(self, text: str) -> list[Candidate]: - code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) - if code_block: - raw = code_block.group(1) - else: - array_match = re.search(r"\[.*\]", text, re.DOTALL) - raw = array_match.group(0) if array_match else text.strip() + items = _extract_json_array(text) + if items is None: + return [] - try: - items = json.loads(raw) - if not isinstance(items, list): - return [] - candidates = [] - for item in items: - if not isinstance(item, dict): - continue - try: - direction_str = item.get("direction", "BUY") - direction = OrderSide(direction_str) - except ValueError: - direction = None - candidates.append( - Candidate( - symbol=item["symbol"], - source="llm", - direction=direction, - score=float(item.get("score", 0.5)), - reason=item.get("reason", ""), - ) + candidates = [] + for item in items: + try: + direction_str = item.get("direction", "BUY") + direction = OrderSide(direction_str) + except ValueError: + direction = None + candidates.append( + Candidate( + symbol=item["symbol"], + source="llm", + direction=direction, + score=float(item.get("score", 0.5)), + reason=item.get("reason", ""), ) - return candidates - except (json.JSONDecodeError, TypeError, KeyError): - return [] + ) + return candidates def _compute_rsi(closes: list[float], period: int = 14) -> float: @@ -217,6 +218,18 @@ class StockSelector: self._api_key = anthropic_api_key self._model = anthropic_model self._max_picks = max_picks + self._http_session: aiohttp.ClientSession | None = None + self._session_lock = asyncio.Lock() + + async def _ensure_session(self) -> aiohttp.ClientSession: + async with self._session_lock: + if self._http_session is None or self._http_session.closed: + self._http_session = aiohttp.ClientSession() + return self._http_session + + async def close(self) -> None: + if self._http_session and not self._http_session.closed: + await self._http_session.close() async def select(self) -> list[SelectedStock]: """Run the full 3-stage pipeline and return selected stocks.""" @@ -235,8 +248,9 @@ class StockSelector: sentiment_source = SentimentCandidateSource(self._db) llm_source = LLMCandidateSource(self._db, self._api_key, self._model) + session = await self._ensure_session() sentiment_candidates = await sentiment_source.get_candidates() - llm_candidates = await llm_source.get_candidates() + llm_candidates = await llm_source.get_candidates(session=session) candidates = self._merge_candidates(sentiment_candidates, llm_candidates) if not candidates: @@ -253,7 +267,7 @@ class StockSelector: selections = await self._llm_final_select(filtered, market_sentiment) # Persist and publish - today = datetime.now(timezone.utc).date() + today = datetime.now(UTC).date() sentiment_snapshot = { "fear_greed": market_sentiment.fear_greed, "market_regime": market_sentiment.market_regime, @@ -372,25 +386,25 @@ class StockSelector: ) try: - async with aiohttp.ClientSession() as session: - async with session.post( - ANTHROPIC_API_URL, - headers={ - "x-api-key": self._api_key, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - json={ - "model": self._model, - "max_tokens": 1024, - "messages": [{"role": "user", "content": prompt}], - }, - ) as resp: - if resp.status != 200: - body = await resp.text() - logger.error("LLM final select error %d: %s", resp.status, body) - return [] - data = await resp.json() + session = await self._ensure_session() + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM final select error %d: %s", resp.status, body) + return [] + data = await resp.json() content = data.get("content", []) text = "" diff --git a/services/strategy-engine/strategies/base.py b/services/strategy-engine/strategies/base.py index d5be675..1d9d289 100644 --- a/services/strategy-engine/strategies/base.py +++ b/services/strategy-engine/strategies/base.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from collections import deque from decimal import Decimal -from typing import Optional import pandas as pd @@ -102,7 +101,7 @@ class BaseStrategy(ABC): def _calculate_atr_stops( self, entry_price: Decimal, side: str - ) -> tuple[Optional[Decimal], Optional[Decimal]]: + ) -> tuple[Decimal | None, Decimal | None]: """Calculate ATR-based stop-loss and take-profit. Returns (stop_loss, take_profit) as Decimal or (None, None) if not enough data. @@ -131,7 +130,7 @@ class BaseStrategy(ABC): return sl, tp - def _apply_filters(self, signal: Signal) -> Optional[Signal]: + def _apply_filters(self, signal: Signal) -> Signal | None: """Apply all filters to a signal. Returns signal with SL/TP or None if filtered out.""" if signal is None: return None diff --git a/services/strategy-engine/strategies/bollinger_strategy.py b/services/strategy-engine/strategies/bollinger_strategy.py index ebe7967..02ff09a 100644 --- a/services/strategy-engine/strategies/bollinger_strategy.py +++ b/services/strategy-engine/strategies/bollinger_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/combined_strategy.py b/services/strategy-engine/strategies/combined_strategy.py index ba92485..f562918 100644 --- a/services/strategy-engine/strategies/combined_strategy.py +++ b/services/strategy-engine/strategies/combined_strategy.py @@ -2,7 +2,7 @@ from decimal import Decimal -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/ema_crossover_strategy.py b/services/strategy-engine/strategies/ema_crossover_strategy.py index 68d0ba3..9c181f3 100644 --- a/services/strategy-engine/strategies/ema_crossover_strategy.py +++ b/services/strategy-engine/strategies/ema_crossover_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/grid_strategy.py b/services/strategy-engine/strategies/grid_strategy.py index 283bfe5..491252e 100644 --- a/services/strategy-engine/strategies/grid_strategy.py +++ b/services/strategy-engine/strategies/grid_strategy.py @@ -1,9 +1,8 @@ from decimal import Decimal -from typing import Optional import numpy as np -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy @@ -17,7 +16,7 @@ class GridStrategy(BaseStrategy): self._grid_count: int = 5 self._quantity: Decimal = Decimal("0.01") self._grid_levels: list[float] = [] - self._last_zone: Optional[int] = None + self._last_zone: int | None = None self._exit_threshold_pct: float = 5.0 self._out_of_range: bool = False self._in_position: bool = False # Track if we have any grid positions diff --git a/services/strategy-engine/strategies/indicators/__init__.py b/services/strategy-engine/strategies/indicators/__init__.py index 3c713e6..01637b7 100644 --- a/services/strategy-engine/strategies/indicators/__init__.py +++ b/services/strategy-engine/strategies/indicators/__init__.py @@ -1,21 +1,21 @@ """Reusable technical indicator functions.""" -from strategies.indicators.trend import ema, sma, macd, adx -from strategies.indicators.volatility import atr, bollinger_bands, keltner_channels from strategies.indicators.momentum import rsi, stochastic -from strategies.indicators.volume import volume_sma, volume_ratio, obv +from strategies.indicators.trend import adx, ema, macd, sma +from strategies.indicators.volatility import atr, bollinger_bands, keltner_channels +from strategies.indicators.volume import obv, volume_ratio, volume_sma __all__ = [ - "ema", - "sma", - "macd", "adx", "atr", "bollinger_bands", + "ema", "keltner_channels", + "macd", + "obv", "rsi", + "sma", "stochastic", - "volume_sma", "volume_ratio", - "obv", + "volume_sma", ] diff --git a/services/strategy-engine/strategies/indicators/momentum.py b/services/strategy-engine/strategies/indicators/momentum.py index c479452..a82210b 100644 --- a/services/strategy-engine/strategies/indicators/momentum.py +++ b/services/strategy-engine/strategies/indicators/momentum.py @@ -1,7 +1,7 @@ """Momentum indicators: RSI, Stochastic.""" -import pandas as pd import numpy as np +import pandas as pd def rsi(closes: pd.Series, period: int = 14) -> pd.Series: diff --git a/services/strategy-engine/strategies/indicators/trend.py b/services/strategy-engine/strategies/indicators/trend.py index c94a071..1085199 100644 --- a/services/strategy-engine/strategies/indicators/trend.py +++ b/services/strategy-engine/strategies/indicators/trend.py @@ -1,7 +1,7 @@ """Trend indicators: EMA, SMA, MACD, ADX.""" -import pandas as pd import numpy as np +import pandas as pd def sma(series: pd.Series, period: int) -> pd.Series: diff --git a/services/strategy-engine/strategies/indicators/volatility.py b/services/strategy-engine/strategies/indicators/volatility.py index c16143e..da82f26 100644 --- a/services/strategy-engine/strategies/indicators/volatility.py +++ b/services/strategy-engine/strategies/indicators/volatility.py @@ -1,7 +1,7 @@ """Volatility indicators: ATR, Bollinger Bands, Keltner Channels.""" -import pandas as pd import numpy as np +import pandas as pd def atr( diff --git a/services/strategy-engine/strategies/indicators/volume.py b/services/strategy-engine/strategies/indicators/volume.py index 502f1ce..d7c6471 100644 --- a/services/strategy-engine/strategies/indicators/volume.py +++ b/services/strategy-engine/strategies/indicators/volume.py @@ -1,7 +1,7 @@ """Volume indicators: Volume SMA, Volume Ratio, OBV.""" -import pandas as pd import numpy as np +import pandas as pd def volume_sma(volumes: pd.Series, period: int = 20) -> pd.Series: diff --git a/services/strategy-engine/strategies/macd_strategy.py b/services/strategy-engine/strategies/macd_strategy.py index 356a42b..b5aea07 100644 --- a/services/strategy-engine/strategies/macd_strategy.py +++ b/services/strategy-engine/strategies/macd_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/moc_strategy.py b/services/strategy-engine/strategies/moc_strategy.py index 7eaa59e..cbc8440 100644 --- a/services/strategy-engine/strategies/moc_strategy.py +++ b/services/strategy-engine/strategies/moc_strategy.py @@ -8,12 +8,12 @@ Rules: """ from collections import deque -from decimal import Decimal from datetime import datetime +from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/rsi_strategy.py b/services/strategy-engine/strategies/rsi_strategy.py index 0646d8c..2df080d 100644 --- a/services/strategy-engine/strategies/rsi_strategy.py +++ b/services/strategy-engine/strategies/rsi_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/volume_profile_strategy.py b/services/strategy-engine/strategies/volume_profile_strategy.py index ef2ae14..67b5c23 100644 --- a/services/strategy-engine/strategies/volume_profile_strategy.py +++ b/services/strategy-engine/strategies/volume_profile_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import numpy as np -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy @@ -137,7 +137,7 @@ class VolumeProfileStrategy(BaseStrategy): if result is None: return None - poc, va_low, va_high, hvn_levels, lvn_levels = result + poc, va_low, va_high, hvn_levels, _lvn_levels = result if close < va_low: self._was_below_va = True diff --git a/services/strategy-engine/strategies/vwap_strategy.py b/services/strategy-engine/strategies/vwap_strategy.py index d64950e..4ee4952 100644 --- a/services/strategy-engine/strategies/vwap_strategy.py +++ b/services/strategy-engine/strategies/vwap_strategy.py @@ -1,7 +1,7 @@ from collections import deque from decimal import Decimal -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy @@ -107,7 +107,7 @@ class VwapStrategy(BaseStrategy): # Standard deviation of (TP - VWAP) for bands std_dev = 0.0 if len(self._tp_values) >= 2: - diffs = [tp - v for tp, v in zip(self._tp_values, self._vwap_values)] + diffs = [tp - v for tp, v in zip(self._tp_values, self._vwap_values, strict=True)] mean_diff = sum(diffs) / len(diffs) variance = sum((d - mean_diff) ** 2 for d in diffs) / len(diffs) std_dev = variance**0.5 diff --git a/services/strategy-engine/tests/test_base_filters.py b/services/strategy-engine/tests/test_base_filters.py index ae9ca05..66adec7 100644 --- a/services/strategy-engine/tests/test_base_filters.py +++ b/services/strategy-engine/tests/test_base_filters.py @@ -5,12 +5,13 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone -from shared.models import Candle, Signal, OrderSide from strategies.base import BaseStrategy +from shared.models import Candle, OrderSide, Signal + class DummyStrategy(BaseStrategy): name = "dummy" @@ -45,7 +46,7 @@ def _candle(price=100.0, volume=10.0, high=None, low=None): return Candle( symbol="AAPL", timeframe="1h", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, tzinfo=UTC), open=Decimal(str(price)), high=Decimal(str(h)), low=Decimal(str(lo)), diff --git a/services/strategy-engine/tests/test_bollinger_strategy.py b/services/strategy-engine/tests/test_bollinger_strategy.py index 8261377..70ec66e 100644 --- a/services/strategy-engine/tests/test_bollinger_strategy.py +++ b/services/strategy-engine/tests/test_bollinger_strategy.py @@ -1,18 +1,18 @@ """Tests for the Bollinger Bands strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.bollinger_strategy import BollingerStrategy from shared.models import Candle, OrderSide -from strategies.bollinger_strategy import BollingerStrategy def make_candle(close: float) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), diff --git a/services/strategy-engine/tests/test_combined_strategy.py b/services/strategy-engine/tests/test_combined_strategy.py index 8a4dc74..6a15250 100644 --- a/services/strategy-engine/tests/test_combined_strategy.py +++ b/services/strategy-engine/tests/test_combined_strategy.py @@ -5,13 +5,14 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone -import pytest -from shared.models import Candle, Signal, OrderSide -from strategies.combined_strategy import CombinedStrategy +import pytest from strategies.base import BaseStrategy +from strategies.combined_strategy import CombinedStrategy + +from shared.models import Candle, OrderSide, Signal class AlwaysBuyStrategy(BaseStrategy): @@ -74,7 +75,7 @@ def _candle(price=100.0): return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, tzinfo=UTC), open=Decimal(str(price)), high=Decimal(str(price + 10)), low=Decimal(str(price - 10)), diff --git a/services/strategy-engine/tests/test_ema_crossover_strategy.py b/services/strategy-engine/tests/test_ema_crossover_strategy.py index 7028eb0..af2b587 100644 --- a/services/strategy-engine/tests/test_ema_crossover_strategy.py +++ b/services/strategy-engine/tests/test_ema_crossover_strategy.py @@ -1,18 +1,18 @@ """Tests for the EMA Crossover strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.ema_crossover_strategy import EmaCrossoverStrategy from shared.models import Candle, OrderSide -from strategies.ema_crossover_strategy import EmaCrossoverStrategy def make_candle(close: float) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), diff --git a/services/strategy-engine/tests/test_engine.py b/services/strategy-engine/tests/test_engine.py index 2623027..fa888b5 100644 --- a/services/strategy-engine/tests/test_engine.py +++ b/services/strategy-engine/tests/test_engine.py @@ -1,21 +1,21 @@ """Tests for the StrategyEngine.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal from unittest.mock import AsyncMock, MagicMock import pytest +from strategy_engine.engine import StrategyEngine -from shared.models import Candle, Signal, OrderSide from shared.events import CandleEvent -from strategy_engine.engine import StrategyEngine +from shared.models import Candle, OrderSide, Signal def make_candle_event() -> dict: candle = Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("50100"), low=Decimal("49900"), diff --git a/services/strategy-engine/tests/test_grid_strategy.py b/services/strategy-engine/tests/test_grid_strategy.py index 878b900..f697012 100644 --- a/services/strategy-engine/tests/test_grid_strategy.py +++ b/services/strategy-engine/tests/test_grid_strategy.py @@ -1,18 +1,18 @@ """Tests for the Grid strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.grid_strategy import GridStrategy from shared.models import Candle, OrderSide -from strategies.grid_strategy import GridStrategy def make_candle(close: float) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), diff --git a/services/strategy-engine/tests/test_indicators.py b/services/strategy-engine/tests/test_indicators.py index 481569b..3147fc4 100644 --- a/services/strategy-engine/tests/test_indicators.py +++ b/services/strategy-engine/tests/test_indicators.py @@ -5,14 +5,13 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) -import pandas as pd import numpy as np +import pandas as pd import pytest - -from strategies.indicators.trend import sma, ema, macd, adx -from strategies.indicators.volatility import atr, bollinger_bands from strategies.indicators.momentum import rsi, stochastic -from strategies.indicators.volume import volume_sma, volume_ratio, obv +from strategies.indicators.trend import adx, ema, macd, sma +from strategies.indicators.volatility import atr, bollinger_bands +from strategies.indicators.volume import obv, volume_ratio, volume_sma class TestTrend: diff --git a/services/strategy-engine/tests/test_macd_strategy.py b/services/strategy-engine/tests/test_macd_strategy.py index 556fd4c..7fac16f 100644 --- a/services/strategy-engine/tests/test_macd_strategy.py +++ b/services/strategy-engine/tests/test_macd_strategy.py @@ -1,18 +1,18 @@ """Tests for the MACD strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.macd_strategy import MacdStrategy from shared.models import Candle, OrderSide -from strategies.macd_strategy import MacdStrategy def _candle(price: float) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(price)), high=Decimal(str(price)), low=Decimal(str(price)), diff --git a/services/strategy-engine/tests/test_moc_strategy.py b/services/strategy-engine/tests/test_moc_strategy.py index 1928a28..076e846 100644 --- a/services/strategy-engine/tests/test_moc_strategy.py +++ b/services/strategy-engine/tests/test_moc_strategy.py @@ -5,19 +5,20 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal -from shared.models import Candle, OrderSide from strategies.moc_strategy import MocStrategy +from shared.models import Candle, OrderSide + def _candle(price, hour=20, minute=0, volume=100.0, day=1, open_price=None): op = open_price if open_price is not None else price - 1 # Default: bullish return Candle( symbol="AAPL", timeframe="5Min", - open_time=datetime(2025, 1, day, hour, minute, tzinfo=timezone.utc), + open_time=datetime(2025, 1, day, hour, minute, tzinfo=UTC), open=Decimal(str(op)), high=Decimal(str(price + 1)), low=Decimal(str(min(op, price) - 1)), diff --git a/services/strategy-engine/tests/test_multi_symbol.py b/services/strategy-engine/tests/test_multi_symbol.py index 671a9d3..922bfc2 100644 --- a/services/strategy-engine/tests/test_multi_symbol.py +++ b/services/strategy-engine/tests/test_multi_symbol.py @@ -9,11 +9,13 @@ import pytest sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +from datetime import UTC, datetime +from decimal import Decimal + from strategy_engine.engine import StrategyEngine + from shared.events import CandleEvent from shared.models import Candle -from decimal import Decimal -from datetime import datetime, timezone @pytest.mark.asyncio @@ -24,7 +26,7 @@ async def test_engine_processes_multiple_streams(): candle_btc = Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("51000"), low=Decimal("49000"), @@ -34,7 +36,7 @@ async def test_engine_processes_multiple_streams(): candle_eth = Candle( symbol="MSFT", timeframe="1m", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, tzinfo=UTC), open=Decimal("3000"), high=Decimal("3100"), low=Decimal("2900"), diff --git a/services/strategy-engine/tests/test_plugin_loader.py b/services/strategy-engine/tests/test_plugin_loader.py index 5191fc3..7bd450f 100644 --- a/services/strategy-engine/tests/test_plugin_loader.py +++ b/services/strategy-engine/tests/test_plugin_loader.py @@ -2,10 +2,8 @@ from pathlib import Path - from strategy_engine.plugin_loader import load_strategies - STRATEGIES_DIR = Path(__file__).parent.parent / "strategies" diff --git a/services/strategy-engine/tests/test_rsi_strategy.py b/services/strategy-engine/tests/test_rsi_strategy.py index 6d31fd5..6c74f0b 100644 --- a/services/strategy-engine/tests/test_rsi_strategy.py +++ b/services/strategy-engine/tests/test_rsi_strategy.py @@ -1,18 +1,18 @@ """Tests for the RSI strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.rsi_strategy import RsiStrategy from shared.models import Candle, OrderSide -from strategies.rsi_strategy import RsiStrategy def make_candle(close: float, idx: int = 0) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), diff --git a/services/strategy-engine/tests/test_stock_selector.py b/services/strategy-engine/tests/test_stock_selector.py index ff9d09c..76b8541 100644 --- a/services/strategy-engine/tests/test_stock_selector.py +++ b/services/strategy-engine/tests/test_stock_selector.py @@ -1,12 +1,12 @@ """Tests for stock selector engine.""" +from datetime import UTC, datetime from unittest.mock import AsyncMock, MagicMock -from datetime import datetime, timezone - from strategy_engine.stock_selector import ( SentimentCandidateSource, StockSelector, + _extract_json_array, _parse_llm_selections, ) @@ -60,6 +60,37 @@ def test_parse_llm_selections_with_markdown(): assert selections[0].symbol == "TSLA" +def test_extract_json_array_from_markdown(): + text = '```json\n[{"symbol": "AAPL", "score": 0.9}]\n```' + result = _extract_json_array(text) + assert result == [{"symbol": "AAPL", "score": 0.9}] + + +def test_extract_json_array_bare(): + text = '[{"symbol": "TSLA"}]' + result = _extract_json_array(text) + assert result == [{"symbol": "TSLA"}] + + +def test_extract_json_array_invalid(): + assert _extract_json_array("not json") is None + + +def test_extract_json_array_filters_non_dicts(): + text = '[{"symbol": "AAPL"}, "bad", 42]' + result = _extract_json_array(text) + assert result == [{"symbol": "AAPL"}] + + +async def test_selector_close(): + selector = StockSelector( + db=MagicMock(), broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test" + ) + # No session yet - close should be safe + await selector.close() + assert selector._http_session is None + + async def test_selector_blocks_on_risk_off(): mock_db = MagicMock() mock_db.get_latest_market_sentiment = AsyncMock( @@ -69,7 +100,7 @@ async def test_selector_blocks_on_risk_off(): "vix": 35.0, "fed_stance": "neutral", "market_regime": "risk_off", - "updated_at": datetime.now(timezone.utc), + "updated_at": datetime.now(UTC), } ) diff --git a/services/strategy-engine/tests/test_strategy_validation.py b/services/strategy-engine/tests/test_strategy_validation.py index debab1f..0d9607a 100644 --- a/services/strategy-engine/tests/test_strategy_validation.py +++ b/services/strategy-engine/tests/test_strategy_validation.py @@ -1,13 +1,11 @@ import pytest - -from strategies.rsi_strategy import RsiStrategy -from strategies.macd_strategy import MacdStrategy from strategies.bollinger_strategy import BollingerStrategy from strategies.ema_crossover_strategy import EmaCrossoverStrategy from strategies.grid_strategy import GridStrategy -from strategies.vwap_strategy import VwapStrategy +from strategies.macd_strategy import MacdStrategy +from strategies.rsi_strategy import RsiStrategy from strategies.volume_profile_strategy import VolumeProfileStrategy - +from strategies.vwap_strategy import VwapStrategy # ── RSI ────────────────────────────────────────────────────────────────── diff --git a/services/strategy-engine/tests/test_volume_profile_strategy.py b/services/strategy-engine/tests/test_volume_profile_strategy.py index 65ee2e8..f47898c 100644 --- a/services/strategy-engine/tests/test_volume_profile_strategy.py +++ b/services/strategy-engine/tests/test_volume_profile_strategy.py @@ -1,18 +1,18 @@ """Tests for the Volume Profile strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.volume_profile_strategy import VolumeProfileStrategy from shared.models import Candle, OrderSide -from strategies.volume_profile_strategy import VolumeProfileStrategy def make_candle(close: float, volume: float = 1.0) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), @@ -134,13 +134,10 @@ def test_volume_profile_hvn_detection(): # Create a profile with very high volume at price ~100 and low volume elsewhere # Prices range from 90 to 110, heavy volume concentrated at 100 - candles_data = [] # Low volume at extremes - for p in [90, 91, 92, 109, 110]: - candles_data.append((p, 1.0)) + candles_data = [(p, 1.0) for p in [90, 91, 92, 109, 110]] # Very high volume around 100 - for _ in range(15): - candles_data.append((100, 100.0)) + candles_data.extend((100, 100.0) for _ in range(15)) for price, vol in candles_data: strategy.on_candle(make_candle(price, vol)) @@ -148,7 +145,7 @@ def test_volume_profile_hvn_detection(): # Access the internal method to verify HVN detection result = strategy._compute_value_area() assert result is not None - poc, va_low, va_high, hvn_levels, lvn_levels = result + _poc, _va_low, _va_high, hvn_levels, _lvn_levels = result # The bin containing price ~100 should have very high volume -> HVN assert len(hvn_levels) > 0 diff --git a/services/strategy-engine/tests/test_vwap_strategy.py b/services/strategy-engine/tests/test_vwap_strategy.py index 2c34b01..078d0cf 100644 --- a/services/strategy-engine/tests/test_vwap_strategy.py +++ b/services/strategy-engine/tests/test_vwap_strategy.py @@ -1,11 +1,11 @@ """Tests for the VWAP strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.vwap_strategy import VwapStrategy from shared.models import Candle, OrderSide -from strategies.vwap_strategy import VwapStrategy def make_candle( @@ -20,7 +20,7 @@ def make_candle( if low is None: low = close if open_time is None: - open_time = datetime(2024, 1, 1, tzinfo=timezone.utc) + open_time = datetime(2024, 1, 1, tzinfo=UTC) return Candle( symbol="AAPL", timeframe="1m", @@ -111,11 +111,11 @@ def test_vwap_daily_reset(): """Candles from two different dates cause VWAP to reset.""" strategy = _configured_strategy() - day1 = datetime(2024, 1, 1, tzinfo=timezone.utc) - day2 = datetime(2024, 1, 2, tzinfo=timezone.utc) + day1 = datetime(2024, 1, 1, tzinfo=UTC) + day2 = datetime(2024, 1, 2, tzinfo=UTC) # Feed 35 candles on day 1 to build VWAP state - for i in range(35): + for _i in range(35): strategy.on_candle(make_candle(100.0, high=101.0, low=99.0, open_time=day1)) # Verify state is built up diff --git a/shared/alembic/versions/001_initial_schema.py b/shared/alembic/versions/001_initial_schema.py index 2bdaafc..7b744ee 100644 --- a/shared/alembic/versions/001_initial_schema.py +++ b/shared/alembic/versions/001_initial_schema.py @@ -5,16 +5,16 @@ Revises: Create Date: 2026-04-01 """ -from typing import Sequence, Union +from collections.abc import Sequence -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "001" -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +down_revision: str | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None def upgrade() -> None: diff --git a/shared/alembic/versions/002_news_sentiment_tables.py b/shared/alembic/versions/002_news_sentiment_tables.py index 402ff87..d85a634 100644 --- a/shared/alembic/versions/002_news_sentiment_tables.py +++ b/shared/alembic/versions/002_news_sentiment_tables.py @@ -5,15 +5,15 @@ Revises: 001 Create Date: 2026-04-02 """ -from typing import Sequence, Union +from collections.abc import Sequence -from alembic import op import sqlalchemy as sa +from alembic import op revision: str = "002" -down_revision: Union[str, None] = "001" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +down_revision: str | None = "001" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None def upgrade() -> None: diff --git a/shared/alembic/versions/003_add_missing_indexes.py b/shared/alembic/versions/003_add_missing_indexes.py new file mode 100644 index 0000000..7a252d4 --- /dev/null +++ b/shared/alembic/versions/003_add_missing_indexes.py @@ -0,0 +1,35 @@ +"""Add missing indexes for common query patterns. + +Revision ID: 003 +Revises: 002 +Create Date: 2026-04-02 +""" + +from collections.abc import Sequence + +from alembic import op + +revision: str = "003" +down_revision: str | None = "002" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_index("idx_signals_symbol_created", "signals", ["symbol", "created_at"]) + op.create_index( + "idx_orders_symbol_status_created", "orders", ["symbol", "status", "created_at"] + ) + op.create_index("idx_trades_order_id", "trades", ["order_id"]) + op.create_index("idx_trades_symbol_traded", "trades", ["symbol", "traded_at"]) + op.create_index("idx_portfolio_snapshots_at", "portfolio_snapshots", ["snapshot_at"]) + op.create_index("idx_symbol_scores_symbol", "symbol_scores", ["symbol"], unique=True) + + +def downgrade() -> None: + op.drop_index("idx_symbol_scores_symbol", table_name="symbol_scores") + op.drop_index("idx_portfolio_snapshots_at", table_name="portfolio_snapshots") + op.drop_index("idx_trades_symbol_traded", table_name="trades") + op.drop_index("idx_trades_order_id", table_name="trades") + op.drop_index("idx_orders_symbol_status_created", table_name="orders") + op.drop_index("idx_signals_symbol_created", table_name="signals") diff --git a/shared/alembic/versions/004_add_signal_detail_columns.py b/shared/alembic/versions/004_add_signal_detail_columns.py new file mode 100644 index 0000000..4009b6e --- /dev/null +++ b/shared/alembic/versions/004_add_signal_detail_columns.py @@ -0,0 +1,25 @@ +"""Add conviction, stop_loss, take_profit columns to signals table. + +Revision ID: 004 +Revises: 003 +""" + +import sqlalchemy as sa +from alembic import op + +revision = "004" +down_revision = "003" + + +def upgrade(): + op.add_column( + "signals", sa.Column("conviction", sa.Float, nullable=False, server_default="1.0") + ) + op.add_column("signals", sa.Column("stop_loss", sa.Numeric, nullable=True)) + op.add_column("signals", sa.Column("take_profit", sa.Numeric, nullable=True)) + + +def downgrade(): + op.drop_column("signals", "take_profit") + op.drop_column("signals", "stop_loss") + op.drop_column("signals", "conviction") diff --git a/shared/pyproject.toml b/shared/pyproject.toml index 830088d..eb74a11 100644 --- a/shared/pyproject.toml +++ b/shared/pyproject.toml @@ -4,28 +4,22 @@ version = "0.1.0" description = "Shared models, events, and utilities for trading platform" requires-python = ">=3.12" dependencies = [ - "pydantic>=2.0", - "pydantic-settings>=2.0", - "redis>=5.0", - "asyncpg>=0.29", - "sqlalchemy[asyncio]>=2.0", - "alembic>=1.13", - "structlog>=24.0", - "prometheus-client>=0.20", - "pyyaml>=6.0", - "aiohttp>=3.9", - "rich>=13.0", + "pydantic>=2.8,<3", + "pydantic-settings>=2.0,<3", + "redis>=5.0,<6", + "asyncpg>=0.29,<1", + "sqlalchemy[asyncio]>=2.0,<3", + "alembic>=1.13,<2", + "structlog>=24.0,<25", + "prometheus-client>=0.20,<1", + "pyyaml>=6.0,<7", + "aiohttp>=3.9,<4", + "rich>=13.0,<14", ] [project.optional-dependencies] -dev = [ - "pytest>=8.0", - "pytest-asyncio>=0.23", - "ruff>=0.4", -] -claude = [ - "anthropic>=0.40", -] +dev = ["pytest>=8.0,<9", "pytest-asyncio>=0.23,<1", "ruff>=0.4,<1"] +claude = ["anthropic>=0.40,<1"] [build-system] requires = ["hatchling"] diff --git a/shared/src/shared/broker.py b/shared/src/shared/broker.py index fbe4576..2b96714 100644 --- a/shared/src/shared/broker.py +++ b/shared/src/shared/broker.py @@ -5,13 +5,21 @@ from typing import Any import redis.asyncio +from shared.resilience import retry_async + class RedisBroker: """Async Redis Streams broker for publishing and reading events.""" def __init__(self, redis_url: str) -> None: - self._redis = redis.asyncio.from_url(redis_url) + self._redis = redis.asyncio.from_url( + redis_url, + socket_keepalive=True, + health_check_interval=30, + retry_on_timeout=True, + ) + @retry_async(max_retries=3, base_delay=0.5, exclude=(ValueError,)) async def publish(self, stream: str, data: dict[str, Any]) -> None: """Publish a message to a Redis stream.""" payload = json.dumps(data) @@ -25,6 +33,7 @@ class RedisBroker: if "BUSYGROUP" not in str(e): raise + @retry_async(max_retries=3, base_delay=0.5, exclude=(ValueError,)) async def read_group( self, stream: str, @@ -99,6 +108,7 @@ class RedisBroker: messages.append(json.loads(payload)) return messages + @retry_async(max_retries=2, base_delay=0.5) async def ping(self) -> bool: """Ping the Redis server; return True if reachable.""" return await self._redis.ping() diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py index 7a947b3..0f1c66e 100644 --- a/shared/src/shared/config.py +++ b/shared/src/shared/config.py @@ -1,14 +1,18 @@ """Shared configuration settings for the trading platform.""" +from pydantic import SecretStr, field_validator from pydantic_settings import BaseSettings class Settings(BaseSettings): - alpaca_api_key: str = "" - alpaca_api_secret: str = "" + alpaca_api_key: SecretStr = SecretStr("") + alpaca_api_secret: SecretStr = SecretStr("") alpaca_paper: bool = True # Use paper trading by default - redis_url: str = "redis://localhost:6379" - database_url: str = "postgresql://trading:trading@localhost:5432/trading" + redis_url: SecretStr = SecretStr("redis://localhost:6379") + database_url: SecretStr = SecretStr("postgresql://trading:trading@localhost:5432/trading") + db_pool_size: int = 20 + db_max_overflow: int = 10 + db_pool_recycle: int = 3600 log_level: str = "INFO" risk_max_position_size: float = 0.1 risk_stop_loss_pct: float = 5.0 @@ -27,24 +31,45 @@ class Settings(BaseSettings): risk_max_consecutive_losses: int = 5 risk_loss_pause_minutes: int = 60 dry_run: bool = True - telegram_bot_token: str = "" + telegram_bot_token: SecretStr = SecretStr("") telegram_chat_id: str = "" telegram_enabled: bool = False log_format: str = "json" health_port: int = 8080 - circuit_breaker_threshold: int = 5 - circuit_breaker_timeout: int = 60 metrics_auth_token: str = "" # If set, /health and /metrics require Bearer token + # API security + api_auth_token: SecretStr = SecretStr("") + cors_origins: str = "http://localhost:3000" # News collector - finnhub_api_key: str = "" + finnhub_api_key: SecretStr = SecretStr("") news_poll_interval: int = 300 sentiment_aggregate_interval: int = 900 # Stock selector - selector_candidates_time: str = "15:00" - selector_filter_time: str = "15:15" selector_final_time: str = "15:30" selector_max_picks: int = 3 # LLM - anthropic_api_key: str = "" + anthropic_api_key: SecretStr = SecretStr("") anthropic_model: str = "claude-sonnet-4-20250514" model_config = {"env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"} + + @field_validator("risk_max_position_size") + @classmethod + def validate_position_size(cls, v: float) -> float: + if v <= 0 or v > 1: + raise ValueError("risk_max_position_size must be between 0 and 1 (exclusive)") + return v + + @field_validator("health_port") + @classmethod + def validate_health_port(cls, v: int) -> int: + if v < 1024 or v > 65535: + raise ValueError("health_port must be between 1024 and 65535") + return v + + @field_validator("log_level") + @classmethod + def validate_log_level(cls, v: str) -> str: + valid = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} + if v.upper() not in valid: + raise ValueError(f"log_level must be one of {valid}") + return v.upper() diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py index 9cc8686..8fee000 100644 --- a/shared/src/shared/db.py +++ b/shared/src/shared/db.py @@ -3,26 +3,25 @@ import json import uuid from contextlib import asynccontextmanager -from datetime import datetime, date, timedelta, timezone +from datetime import UTC, date, datetime, timedelta from decimal import Decimal -from typing import Optional from sqlalchemy import select, update -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from shared.models import Candle, Signal, Order, OrderStatus, NewsItem -from shared.sentiment_models import SymbolScore, MarketSentiment +from shared.models import Candle, NewsItem, Order, OrderStatus, Signal from shared.sa_models import ( Base, CandleRow, - SignalRow, + MarketSentimentRow, + NewsItemRow, OrderRow, PortfolioSnapshotRow, - NewsItemRow, - SymbolScoreRow, - MarketSentimentRow, + SignalRow, StockSelectionRow, + SymbolScoreRow, ) +from shared.sentiment_models import MarketSentiment, SymbolScore class Database: @@ -36,9 +35,24 @@ class Database: self._engine = None self._session_factory = None - async def connect(self) -> None: + async def connect( + self, + pool_size: int = 20, + max_overflow: int = 10, + pool_recycle: int = 3600, + ) -> None: """Create the async engine, session factory, and all tables.""" - self._engine = create_async_engine(self._database_url) + if self._database_url.startswith("sqlite"): + # SQLite doesn't support pooling options + self._engine = create_async_engine(self._database_url) + else: + self._engine = create_async_engine( + self._database_url, + pool_pre_ping=True, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + ) self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False) async with self._engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) @@ -98,6 +112,9 @@ class Database: price=signal.price, quantity=signal.quantity, reason=signal.reason, + conviction=signal.conviction, + stop_loss=signal.stop_loss, + take_profit=signal.take_profit, created_at=signal.created_at, ) async with self._session_factory() as session: @@ -134,7 +151,7 @@ class Database: self, order_id: str, status: OrderStatus, - filled_at: Optional[datetime] = None, + filled_at: datetime | None = None, ) -> None: """Update the status (and optionally filled_at) of an order.""" stmt = ( @@ -180,7 +197,7 @@ class Database: total_value=total_value, realized_pnl=realized_pnl, unrealized_pnl=unrealized_pnl, - snapshot_at=datetime.now(timezone.utc), + snapshot_at=datetime.now(UTC), ) session.add(row) await session.commit() @@ -191,7 +208,7 @@ class Database: async def get_portfolio_snapshots(self, days: int = 30) -> list[dict]: """Retrieve recent portfolio snapshots.""" async with self.get_session() as session: - since = datetime.now(timezone.utc) - timedelta(days=days) + since = datetime.now(UTC) - timedelta(days=days) stmt = ( select(PortfolioSnapshotRow) .where(PortfolioSnapshotRow.snapshot_at >= since) @@ -234,7 +251,7 @@ class Database: async def get_recent_news(self, hours: int = 24) -> list[dict]: """Retrieve news items published in the last N hours.""" - since = datetime.now(timezone.utc) - timedelta(hours=hours) + since = datetime.now(UTC) - timedelta(hours=hours) stmt = ( select(NewsItemRow) .where(NewsItemRow.published_at >= since) @@ -352,7 +369,7 @@ class Database: await session.rollback() raise - async def get_latest_market_sentiment(self) -> Optional[dict]: + async def get_latest_market_sentiment(self) -> dict | None: """Retrieve the 'latest' market sentiment row, or None if not found.""" stmt = select(MarketSentimentRow).where(MarketSentimentRow.id == "latest") async with self._session_factory() as session: @@ -394,7 +411,7 @@ class Database: reason=reason, key_news=json.dumps(key_news), sentiment_snapshot=json.dumps(sentiment_snapshot), - created_at=datetime.now(timezone.utc), + created_at=datetime.now(UTC), ) async with self._session_factory() as session: try: diff --git a/shared/src/shared/events.py b/shared/src/shared/events.py index 63f93a2..37217a0 100644 --- a/shared/src/shared/events.py +++ b/shared/src/shared/events.py @@ -1,14 +1,14 @@ """Event types and serialization for the trading platform.""" -from enum import Enum +from enum import StrEnum from typing import Any from pydantic import BaseModel -from shared.models import Candle, Signal, Order, NewsItem +from shared.models import Candle, NewsItem, Order, Signal -class EventType(str, Enum): +class EventType(StrEnum): CANDLE = "CANDLE" SIGNAL = "SIGNAL" ORDER = "ORDER" @@ -88,6 +88,16 @@ class Event: @staticmethod def from_dict(data: dict) -> Any: - event_type = EventType(data["type"]) + """Deserialize a raw dict into the appropriate event type. + + Raises ValueError for malformed or unrecognized event data. + """ + try: + event_type = EventType(data["type"]) + except (KeyError, ValueError) as exc: + raise ValueError(f"Invalid or missing event type in data: {data!r}") from exc cls = _EVENT_TYPE_MAP[event_type] - return cls.from_raw(data) + try: + return cls.from_raw(data) + except KeyError as exc: + raise ValueError(f"Missing required field in {event_type} event data: {exc}") from exc diff --git a/shared/src/shared/healthcheck.py b/shared/src/shared/healthcheck.py index 7411e8a..a19705b 100644 --- a/shared/src/shared/healthcheck.py +++ b/shared/src/shared/healthcheck.py @@ -3,10 +3,11 @@ from __future__ import annotations import time -from typing import Any, Callable, Awaitable +from collections.abc import Awaitable, Callable +from typing import Any from aiohttp import web -from prometheus_client import CollectorRegistry, REGISTRY, generate_latest, CONTENT_TYPE_LATEST +from prometheus_client import CONTENT_TYPE_LATEST, REGISTRY, CollectorRegistry, generate_latest class HealthCheckServer: diff --git a/shared/src/shared/metrics.py b/shared/src/shared/metrics.py index cd239f3..6189143 100644 --- a/shared/src/shared/metrics.py +++ b/shared/src/shared/metrics.py @@ -2,7 +2,7 @@ from __future__ import annotations -from prometheus_client import Counter, Gauge, Histogram, CollectorRegistry, REGISTRY +from prometheus_client import REGISTRY, CollectorRegistry, Counter, Gauge, Histogram class ServiceMetrics: diff --git a/shared/src/shared/models.py b/shared/src/shared/models.py index a436c03..f357f9f 100644 --- a/shared/src/shared/models.py +++ b/shared/src/shared/models.py @@ -1,25 +1,24 @@ """Shared Pydantic models for the trading platform.""" import uuid +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone -from enum import Enum -from typing import Optional +from enum import StrEnum from pydantic import BaseModel, Field, computed_field -class OrderSide(str, Enum): +class OrderSide(StrEnum): BUY = "BUY" SELL = "SELL" -class OrderType(str, Enum): +class OrderType(StrEnum): MARKET = "MARKET" LIMIT = "LIMIT" -class OrderStatus(str, Enum): +class OrderStatus(StrEnum): PENDING = "PENDING" FILLED = "FILLED" CANCELLED = "CANCELLED" @@ -46,9 +45,9 @@ class Signal(BaseModel): quantity: Decimal reason: str conviction: float = 1.0 # 0.0 to 1.0, signal strength/confidence - stop_loss: Optional[Decimal] = None # Price to exit at loss - take_profit: Optional[Decimal] = None # Price to exit at profit - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + stop_loss: Decimal | None = None # Price to exit at loss + take_profit: Decimal | None = None # Price to exit at profit + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) class Order(BaseModel): @@ -60,8 +59,8 @@ class Order(BaseModel): price: Decimal quantity: Decimal status: OrderStatus = OrderStatus.PENDING - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) - filled_at: Optional[datetime] = None + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + filled_at: datetime | None = None class Position(BaseModel): @@ -76,7 +75,7 @@ class Position(BaseModel): return self.quantity * (self.current_price - self.avg_entry_price) -class NewsCategory(str, Enum): +class NewsCategory(StrEnum): POLICY = "policy" EARNINGS = "earnings" MACRO = "macro" @@ -89,11 +88,11 @@ class NewsItem(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) source: str headline: str - summary: Optional[str] = None - url: Optional[str] = None + summary: str | None = None + url: str | None = None published_at: datetime symbols: list[str] = [] sentiment: float category: NewsCategory raw_data: dict = {} - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) diff --git a/shared/src/shared/notifier.py b/shared/src/shared/notifier.py index 3d7b6cf..cfc86cd 100644 --- a/shared/src/shared/notifier.py +++ b/shared/src/shared/notifier.py @@ -2,13 +2,13 @@ import asyncio import logging +from collections.abc import Sequence from decimal import Decimal -from typing import Optional, Sequence import aiohttp -from shared.models import Signal, Order, Position -from shared.sentiment_models import SelectedStock, MarketSentiment +from shared.models import Order, Position, Signal +from shared.sentiment_models import MarketSentiment, SelectedStock logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ class TelegramNotifier: self._bot_token = bot_token self._chat_id = chat_id self._semaphore = asyncio.Semaphore(1) - self._session: Optional[aiohttp.ClientSession] = None + self._session: aiohttp.ClientSession | None = None @property def enabled(self) -> bool: @@ -113,13 +113,13 @@ class TelegramNotifier: "", "<b>Positions:</b>", ] - for pos in positions: - lines.append( - f" {pos.symbol}: qty={pos.quantity} " - f"entry={pos.avg_entry_price} " - f"current={pos.current_price} " - f"pnl={pos.unrealized_pnl}" - ) + lines.extend( + f" {pos.symbol}: qty={pos.quantity} " + f"entry={pos.avg_entry_price} " + f"current={pos.current_price} " + f"pnl={pos.unrealized_pnl}" + for pos in positions + ) if not positions: lines.append(" No open positions") await self.send("\n".join(lines)) diff --git a/shared/src/shared/resilience.py b/shared/src/shared/resilience.py index e43fd21..66225d7 100644 --- a/shared/src/shared/resilience.py +++ b/shared/src/shared/resilience.py @@ -1,29 +1,45 @@ -"""Retry with exponential backoff and circuit breaker utilities.""" +"""Resilience utilities for the trading platform. + +Provides retry, circuit breaker, and timeout primitives using only stdlib. +No external dependencies required. +""" from __future__ import annotations import asyncio -import enum import functools import logging import random import time -from typing import Any, Callable +from collections.abc import Callable +from contextlib import asynccontextmanager +from enum import StrEnum +from typing import Any -logger = logging.getLogger(__name__) + +class _State(StrEnum): + CLOSED = "closed" + OPEN = "open" + HALF_OPEN = "half_open" -# --------------------------------------------------------------------------- -# retry_with_backoff -# --------------------------------------------------------------------------- +logger = logging.getLogger(__name__) -def retry_with_backoff( +def retry_async( max_retries: int = 3, base_delay: float = 1.0, - max_delay: float = 60.0, + max_delay: float = 30.0, + exclude: tuple[type[BaseException], ...] = (), ) -> Callable: - """Decorator that retries an async function with exponential backoff + jitter.""" + """Decorator: exponential backoff + jitter for async functions. + + Parameters: + max_retries: Maximum number of retry attempts (after the initial call). + base_delay: Base delay in seconds for exponential backoff. + max_delay: Maximum delay cap in seconds. + exclude: Exception types that should NOT be retried (raised immediately). + """ def decorator(func: Callable) -> Callable: @functools.wraps(func) @@ -33,20 +49,21 @@ def retry_with_backoff( try: return await func(*args, **kwargs) except Exception as exc: + if exclude and isinstance(exc, exclude): + raise last_exc = exc if attempt < max_retries: delay = min(base_delay * (2**attempt), max_delay) - jitter = delay * random.uniform(0, 0.5) - total_delay = delay + jitter + jitter_delay = delay * random.uniform(0.5, 1.0) logger.warning( - "Retry %d/%d for %s after error: %s (delay=%.3fs)", + "Retry %d/%d for %s in %.2fs: %s", attempt + 1, max_retries, func.__name__, + jitter_delay, exc, - total_delay, ) - await asyncio.sleep(total_delay) + await asyncio.sleep(jitter_delay) raise last_exc # type: ignore[misc] return wrapper @@ -54,52 +71,65 @@ def retry_with_backoff( return decorator -# --------------------------------------------------------------------------- -# CircuitBreaker -# --------------------------------------------------------------------------- - - -class CircuitState(enum.Enum): - CLOSED = "closed" - OPEN = "open" - HALF_OPEN = "half_open" +class CircuitBreaker: + """Circuit breaker: opens after N consecutive failures, auto-recovers. + States: closed -> open -> half_open -> closed -class CircuitBreaker: - """Simple circuit breaker implementation.""" + Parameters: + failure_threshold: Number of consecutive failures before opening. + cooldown: Seconds to wait before allowing a half-open probe. + """ - def __init__( - self, - failure_threshold: int = 5, - recovery_timeout: float = 60.0, - ) -> None: + def __init__(self, failure_threshold: int = 5, cooldown: float = 60.0) -> None: self._failure_threshold = failure_threshold - self._recovery_timeout = recovery_timeout - self._failure_count: int = 0 - self._state = CircuitState.CLOSED + self._cooldown = cooldown + self._failures = 0 + self._state = _State.CLOSED self._opened_at: float = 0.0 - @property - def state(self) -> CircuitState: - return self._state - - def allow_request(self) -> bool: - if self._state == CircuitState.CLOSED: - return True - if self._state == CircuitState.OPEN: - if time.monotonic() - self._opened_at >= self._recovery_timeout: - self._state = CircuitState.HALF_OPEN - return True - return False - # HALF_OPEN - return True - - def record_success(self) -> None: - self._failure_count = 0 - self._state = CircuitState.CLOSED - - def record_failure(self) -> None: - self._failure_count += 1 - if self._failure_count >= self._failure_threshold: - self._state = CircuitState.OPEN - self._opened_at = time.monotonic() + async def call(self, func: Callable, *args: Any, **kwargs: Any) -> Any: + """Execute func through the breaker.""" + if self._state == _State.OPEN: + if time.monotonic() - self._opened_at >= self._cooldown: + self._state = _State.HALF_OPEN + else: + raise RuntimeError("Circuit breaker is open") + + try: + result = await func(*args, **kwargs) + except Exception: + self._failures += 1 + if self._state == _State.HALF_OPEN: + self._state = _State.OPEN + self._opened_at = time.monotonic() + logger.error( + "Circuit breaker re-opened after half-open probe failure (threshold=%d)", + self._failure_threshold, + ) + elif self._failures >= self._failure_threshold: + self._state = _State.OPEN + self._opened_at = time.monotonic() + logger.error( + "Circuit breaker opened after %d consecutive failures", + self._failures, + ) + raise + + # Success: reset + self._failures = 0 + self._state = _State.CLOSED + return result + + +@asynccontextmanager +async def async_timeout(seconds: float): + """Async context manager wrapping asyncio.timeout(). + + Raises TimeoutError with a descriptive message on timeout. + """ + try: + async with asyncio.timeout(seconds): + yield + except TimeoutError: + raise TimeoutError(f"Operation timed out after {seconds}s") from None diff --git a/shared/src/shared/sa_models.py b/shared/src/shared/sa_models.py index 1bd92c2..b70a6c4 100644 --- a/shared/src/shared/sa_models.py +++ b/shared/src/shared/sa_models.py @@ -35,6 +35,9 @@ class SignalRow(Base): price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) reason: Mapped[str | None] = mapped_column(Text) + conviction: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="1.0") + stop_loss: Mapped[Decimal | None] = mapped_column(Numeric) + take_profit: Mapped[Decimal | None] = mapped_column(Numeric) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) diff --git a/shared/src/shared/sentiment.py b/shared/src/shared/sentiment.py index 449eb76..c56da3e 100644 --- a/shared/src/shared/sentiment.py +++ b/shared/src/shared/sentiment.py @@ -1,41 +1,10 @@ -"""Market sentiment data.""" +"""Market sentiment aggregation.""" -import logging -from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import datetime +from typing import ClassVar from shared.sentiment_models import SymbolScore -logger = logging.getLogger(__name__) - - -@dataclass -class SentimentData: - """Aggregated sentiment snapshot.""" - - fear_greed_value: int | None = None - fear_greed_label: str | None = None - news_sentiment: float | None = None - news_count: int = 0 - exchange_netflow: float | None = None - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - - @property - def should_buy(self) -> bool: - if self.fear_greed_value is not None and self.fear_greed_value > 70: - return False - if self.news_sentiment is not None and self.news_sentiment < -0.3: - return False - return True - - @property - def should_block(self) -> bool: - if self.fear_greed_value is not None and self.fear_greed_value > 80: - return True - if self.news_sentiment is not None and self.news_sentiment < -0.5: - return True - return False - def _safe_avg(values: list[float]) -> float: if not values: @@ -46,9 +15,9 @@ def _safe_avg(values: list[float]) -> float: class SentimentAggregator: """Aggregates per-news sentiment into per-symbol scores.""" - WEIGHTS = {"news": 0.3, "social": 0.2, "policy": 0.3, "filing": 0.2} + WEIGHTS: ClassVar[dict[str, float]] = {"news": 0.3, "social": 0.2, "policy": 0.3, "filing": 0.2} - CATEGORY_MAP = { + CATEGORY_MAP: ClassVar[dict[str, str]] = { "earnings": "news", "macro": "news", "social": "social", diff --git a/shared/src/shared/sentiment_models.py b/shared/src/shared/sentiment_models.py index a009601..ac06c20 100644 --- a/shared/src/shared/sentiment_models.py +++ b/shared/src/shared/sentiment_models.py @@ -1,7 +1,6 @@ """Sentiment scoring and stock selection models.""" from datetime import datetime -from typing import Optional from pydantic import BaseModel @@ -22,7 +21,7 @@ class SymbolScore(BaseModel): class MarketSentiment(BaseModel): fear_greed: int fear_greed_label: str - vix: Optional[float] = None + vix: float | None = None fed_stance: str market_regime: str updated_at: datetime @@ -39,6 +38,6 @@ class SelectedStock(BaseModel): class Candidate(BaseModel): symbol: str source: str - direction: Optional[OrderSide] = None + direction: OrderSide | None = None score: float reason: str diff --git a/shared/src/shared/shutdown.py b/shared/src/shared/shutdown.py new file mode 100644 index 0000000..4ed9aa7 --- /dev/null +++ b/shared/src/shared/shutdown.py @@ -0,0 +1,30 @@ +"""Graceful shutdown utilities for services.""" + +import asyncio +import logging +import signal + +logger = logging.getLogger(__name__) + + +class GracefulShutdown: + """Manages graceful shutdown via SIGTERM/SIGINT signals.""" + + def __init__(self) -> None: + self._event = asyncio.Event() + + @property + def is_shutting_down(self) -> bool: + return self._event.is_set() + + async def wait(self) -> None: + await self._event.wait() + + def trigger(self) -> None: + logger.info("shutdown_signal_received") + self._event.set() + + def install_handlers(self) -> None: + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, self.trigger) diff --git a/shared/tests/test_alpaca.py b/shared/tests/test_alpaca.py index 080b7c4..55a2b24 100644 --- a/shared/tests/test_alpaca.py +++ b/shared/tests/test_alpaca.py @@ -1,7 +1,9 @@ """Tests for Alpaca API client.""" -import pytest from unittest.mock import AsyncMock, MagicMock + +import pytest + from shared.alpaca import AlpacaClient diff --git a/shared/tests/test_broker.py b/shared/tests/test_broker.py index eb1582d..5636611 100644 --- a/shared/tests/test_broker.py +++ b/shared/tests/test_broker.py @@ -1,10 +1,11 @@ """Tests for the Redis broker.""" -import pytest import json -import redis from unittest.mock import AsyncMock, patch +import pytest +import redis + @pytest.mark.asyncio async def test_broker_publish(): diff --git a/shared/tests/test_config_validation.py b/shared/tests/test_config_validation.py new file mode 100644 index 0000000..9376dc6 --- /dev/null +++ b/shared/tests/test_config_validation.py @@ -0,0 +1,29 @@ +"""Tests for config validation.""" + +import pytest +from pydantic import ValidationError + +from shared.config import Settings + + +class TestConfigValidation: + def test_valid_defaults(self): + settings = Settings() + assert settings.risk_max_position_size == 0.1 + + def test_invalid_position_size(self): + with pytest.raises(ValidationError, match="risk_max_position_size"): + Settings(risk_max_position_size=-0.1) + + def test_invalid_health_port(self): + with pytest.raises(ValidationError, match="health_port"): + Settings(health_port=80) + + def test_invalid_log_level(self): + with pytest.raises(ValidationError, match="log_level"): + Settings(log_level="INVALID") + + def test_secret_fields_masked(self): + settings = Settings(alpaca_api_key="my-secret-key") + assert "my-secret-key" not in repr(settings) + assert settings.alpaca_api_key.get_secret_value() == "my-secret-key" diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index 239ee64..b44a713 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -1,10 +1,11 @@ """Tests for the SQLAlchemy async database layer.""" -import pytest +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch +import pytest + def make_candle(): from shared.models import Candle @@ -12,7 +13,7 @@ def make_candle(): return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("51000"), low=Decimal("49500"), @@ -22,7 +23,7 @@ def make_candle(): def make_signal(): - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal return Signal( id="sig-1", @@ -32,12 +33,12 @@ def make_signal(): price=Decimal("50000"), quantity=Decimal("0.1"), reason="Golden cross", - created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + created_at=datetime(2024, 1, 1, tzinfo=UTC), ) def make_order(): - from shared.models import Order, OrderSide, OrderType, OrderStatus + from shared.models import Order, OrderSide, OrderStatus, OrderType return Order( id="ord-1", @@ -48,7 +49,7 @@ def make_order(): price=Decimal("50000"), quantity=Decimal("0.1"), status=OrderStatus.PENDING, - created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + created_at=datetime(2024, 1, 1, tzinfo=UTC), ) @@ -101,6 +102,54 @@ class TestDatabaseConnect: mock_create.assert_called_once() @pytest.mark.asyncio + async def test_connect_passes_pool_params_for_postgres(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect(pool_size=5, max_overflow=3, pool_recycle=1800) + mock_create.assert_called_once_with( + "postgresql+asyncpg://host/db", + pool_pre_ping=True, + pool_size=5, + max_overflow=3, + pool_recycle=1800, + ) + + @pytest.mark.asyncio + async def test_connect_skips_pool_params_for_sqlite(self): + from shared.db import Database + + db = Database("sqlite+aiosqlite:///test.db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect() + mock_create.assert_called_once_with("sqlite+aiosqlite:///test.db") + + @pytest.mark.asyncio async def test_init_tables_is_alias_for_connect(self): from shared.db import Database @@ -211,7 +260,7 @@ class TestUpdateOrderStatus: db._session_factory = MagicMock(return_value=mock_session) - filled = datetime(2024, 1, 2, tzinfo=timezone.utc) + filled = datetime(2024, 1, 2, tzinfo=UTC) await db.update_order_status("ord-1", OrderStatus.FILLED, filled) mock_session.execute.assert_awaited_once() @@ -230,7 +279,7 @@ class TestGetCandles: mock_row._mapping = { "symbol": "AAPL", "timeframe": "1m", - "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc), + "open_time": datetime(2024, 1, 1, tzinfo=UTC), "open": Decimal("50000"), "high": Decimal("51000"), "low": Decimal("49500"), @@ -396,7 +445,7 @@ class TestGetPortfolioSnapshots: mock_row.total_value = Decimal("10000") mock_row.realized_pnl = Decimal("0") mock_row.unrealized_pnl = Decimal("500") - mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=timezone.utc) + mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=UTC) mock_result = MagicMock() mock_result.scalars.return_value.all.return_value = [mock_row] diff --git a/shared/tests/test_db_news.py b/shared/tests/test_db_news.py index a2c9140..c184bed 100644 --- a/shared/tests/test_db_news.py +++ b/shared/tests/test_db_news.py @@ -1,11 +1,12 @@ """Tests for database news/sentiment methods. Uses in-memory SQLite.""" +from datetime import UTC, date, datetime + import pytest -from datetime import datetime, date, timezone from shared.db import Database -from shared.models import NewsItem, NewsCategory -from shared.sentiment_models import SymbolScore, MarketSentiment +from shared.models import NewsCategory, NewsItem +from shared.sentiment_models import MarketSentiment, SymbolScore @pytest.fixture @@ -20,7 +21,7 @@ async def test_insert_and_get_news_items(db): item = NewsItem( source="finnhub", headline="AAPL earnings beat", - published_at=datetime(2026, 4, 2, 12, 0, tzinfo=timezone.utc), + published_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC), sentiment=0.8, category=NewsCategory.EARNINGS, symbols=["AAPL"], @@ -40,7 +41,7 @@ async def test_upsert_symbol_score(db): policy_score=0.0, filing_score=0.2, composite=0.3, - updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + updated_at=datetime(2026, 4, 2, tzinfo=UTC), ) await db.upsert_symbol_score(score) scores = await db.get_top_symbol_scores(limit=5) @@ -55,7 +56,7 @@ async def test_upsert_market_sentiment(db): vix=18.2, fed_stance="neutral", market_regime="neutral", - updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + updated_at=datetime(2026, 4, 2, tzinfo=UTC), ) await db.upsert_market_sentiment(ms) result = await db.get_latest_market_sentiment() diff --git a/shared/tests/test_events.py b/shared/tests/test_events.py index 6077d93..1ccd904 100644 --- a/shared/tests/test_events.py +++ b/shared/tests/test_events.py @@ -1,7 +1,7 @@ """Tests for shared event types.""" +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone def make_candle(): @@ -10,7 +10,7 @@ def make_candle(): return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("51000"), low=Decimal("49500"), @@ -20,7 +20,7 @@ def make_candle(): def make_signal(): - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal return Signal( strategy="test", @@ -59,7 +59,7 @@ def test_candle_event_deserialize(): def test_signal_event_serialize(): """Test SignalEvent serializes to dict correctly.""" - from shared.events import SignalEvent, EventType + from shared.events import EventType, SignalEvent signal = make_signal() event = SignalEvent(data=signal) @@ -71,7 +71,7 @@ def test_signal_event_serialize(): def test_event_from_dict_dispatch(): """Test Event.from_dict dispatches to correct class.""" - from shared.events import Event, CandleEvent, SignalEvent + from shared.events import CandleEvent, Event, SignalEvent candle = make_candle() event = CandleEvent(data=candle) diff --git a/shared/tests/test_models.py b/shared/tests/test_models.py index 04098ce..40bb791 100644 --- a/shared/tests/test_models.py +++ b/shared/tests/test_models.py @@ -1,8 +1,8 @@ """Tests for shared models and settings.""" import os +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone from unittest.mock import patch @@ -12,8 +12,11 @@ def test_settings_defaults(): with patch.dict(os.environ, {}, clear=False): settings = Settings() - assert settings.redis_url == "redis://localhost:6379" - assert settings.database_url == "postgresql://trading:trading@localhost:5432/trading" + assert settings.redis_url.get_secret_value() == "redis://localhost:6379" + assert ( + settings.database_url.get_secret_value() + == "postgresql://trading:trading@localhost:5432/trading" + ) assert settings.log_level == "INFO" assert settings.risk_max_position_size == 0.1 assert settings.risk_stop_loss_pct == 5.0 @@ -25,7 +28,7 @@ def test_candle_creation(): """Test Candle model creation.""" from shared.models import Candle - now = datetime.now(timezone.utc) + now = datetime.now(UTC) candle = Candle( symbol="AAPL", timeframe="1m", @@ -47,7 +50,7 @@ def test_candle_creation(): def test_signal_creation(): """Test Signal model creation.""" - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal signal = Signal( strategy="rsi_strategy", @@ -69,9 +72,10 @@ def test_signal_creation(): def test_order_creation(): """Test Order model creation with defaults.""" - from shared.models import Order, OrderSide, OrderType, OrderStatus import uuid + from shared.models import Order, OrderSide, OrderStatus, OrderType + signal_id = str(uuid.uuid4()) order = Order( signal_id=signal_id, @@ -90,7 +94,7 @@ def test_order_creation(): def test_signal_conviction_default(): """Test Signal defaults for conviction, stop_loss, take_profit.""" - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal signal = Signal( strategy="rsi", @@ -107,7 +111,7 @@ def test_signal_conviction_default(): def test_signal_with_stops(): """Test Signal with explicit conviction, stop_loss, take_profit.""" - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal signal = Signal( strategy="rsi", diff --git a/shared/tests/test_news_events.py b/shared/tests/test_news_events.py index 384796a..f748d8a 100644 --- a/shared/tests/test_news_events.py +++ b/shared/tests/test_news_events.py @@ -1,16 +1,16 @@ """Tests for NewsEvent.""" -from datetime import datetime, timezone +from datetime import UTC, datetime +from shared.events import Event, EventType, NewsEvent from shared.models import NewsCategory, NewsItem -from shared.events import NewsEvent, EventType, Event def test_news_event_to_dict(): item = NewsItem( source="finnhub", headline="Test", - published_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + published_at=datetime(2026, 4, 2, tzinfo=UTC), sentiment=0.5, category=NewsCategory.MACRO, ) diff --git a/shared/tests/test_notifier.py b/shared/tests/test_notifier.py index 6c81369..cc98a56 100644 --- a/shared/tests/test_notifier.py +++ b/shared/tests/test_notifier.py @@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from shared.models import Signal, Order, OrderSide, OrderType, OrderStatus, Position +from shared.models import Order, OrderSide, OrderStatus, OrderType, Position, Signal from shared.notifier import TelegramNotifier diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py index e287777..e0781af 100644 --- a/shared/tests/test_resilience.py +++ b/shared/tests/test_resilience.py @@ -1,139 +1,176 @@ -"""Tests for retry with backoff and circuit breaker.""" +"""Tests for shared.resilience module.""" -import time +import asyncio import pytest -from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff +from shared.resilience import CircuitBreaker, async_timeout, retry_async +# --- retry_async tests --- -# --------------------------------------------------------------------------- -# retry_with_backoff tests -# --------------------------------------------------------------------------- - -@pytest.mark.asyncio -async def test_retry_succeeds_first_try(): +async def test_succeeds_without_retry(): + """Function succeeds first try, called once.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def succeed(): + @retry_async() + async def fn(): nonlocal call_count call_count += 1 return "ok" - result = await succeed() + result = await fn() assert result == "ok" assert call_count == 1 -@pytest.mark.asyncio -async def test_retry_succeeds_after_failures(): +async def test_retries_on_failure_then_succeeds(): + """Fails twice then succeeds, verify call count.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def flaky(): + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): nonlocal call_count call_count += 1 if call_count < 3: - raise ValueError("not yet") + raise RuntimeError("transient") return "recovered" - result = await flaky() + result = await fn() assert result == "recovered" assert call_count == 3 -@pytest.mark.asyncio -async def test_retry_raises_after_max_retries(): +async def test_raises_after_max_retries(): + """Always fails, raises after max retries.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def always_fail(): + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): nonlocal call_count call_count += 1 - raise RuntimeError("permanent") + raise ValueError("permanent") - with pytest.raises(RuntimeError, match="permanent"): - await always_fail() - # 1 initial + 3 retries = 4 calls + with pytest.raises(ValueError, match="permanent"): + await fn() + + # 1 initial + 3 retries = 4 total calls assert call_count == 4 -@pytest.mark.asyncio -async def test_retry_respects_max_delay(): - """Backoff should be capped at max_delay.""" +async def test_no_retry_on_excluded_exception(): + """Excluded exception raises immediately, call count = 1.""" + call_count = 0 - @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02) - async def always_fail(): - raise RuntimeError("fail") + @retry_async(max_retries=3, base_delay=0.01, exclude=(TypeError,)) + async def fn(): + nonlocal call_count + call_count += 1 + raise TypeError("excluded") - start = time.monotonic() - with pytest.raises(RuntimeError): - await always_fail() - elapsed = time.monotonic() - start - # With max_delay=0.02 and 2 retries, total delay should be small - assert elapsed < 0.5 + with pytest.raises(TypeError, match="excluded"): + await fn() + + assert call_count == 1 -# --------------------------------------------------------------------------- -# CircuitBreaker tests -# --------------------------------------------------------------------------- +# --- CircuitBreaker tests --- -def test_circuit_starts_closed(): - cb = CircuitBreaker(failure_threshold=3, recovery_timeout=0.05) - assert cb.state == CircuitState.CLOSED - assert cb.allow_request() is True +async def test_closed_allows_calls(): + """CircuitBreaker in closed state passes through.""" + cb = CircuitBreaker(failure_threshold=5, cooldown=60.0) + async def fn(): + return "ok" + + result = await cb.call(fn) + assert result == "ok" + + +async def test_opens_after_threshold(): + """After N failures, raises RuntimeError.""" + cb = CircuitBreaker(failure_threshold=3, cooldown=60.0) + + async def fail(): + raise RuntimeError("fail") -def test_circuit_opens_after_threshold(): - cb = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0) for _ in range(3): - cb.record_failure() - assert cb.state == CircuitState.OPEN - assert cb.allow_request() is False + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + # Now the breaker should be open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + +async def test_half_open_after_cooldown(): + """After cooldown, allows recovery attempt.""" + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def fail(): + raise RuntimeError("fail") + + # Trip the breaker + for _ in range(2): + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + + # Breaker is open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + # Wait for cooldown + await asyncio.sleep(0.06) + + # Now should allow a call (half_open). Succeed to close it. + async def succeed(): + return "recovered" + + result = await cb.call(succeed) + assert result == "recovered" + + # Breaker should be closed again + result = await cb.call(succeed) + assert result == "recovered" + + +async def test_half_open_reopens_on_failure(): + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def always_fail(): + raise ConnectionError("fail") -def test_circuit_rejects_when_open(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN - assert cb.allow_request() is False + # Trip the breaker + for _ in range(2): + with pytest.raises(ConnectionError): + await cb.call(always_fail) + # Wait for cooldown + await asyncio.sleep(0.1) -def test_circuit_half_open_after_timeout(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN + # Half-open probe should fail and re-open + with pytest.raises(ConnectionError): + await cb.call(always_fail) - time.sleep(0.06) - assert cb.allow_request() is True - assert cb.state == CircuitState.HALF_OPEN + # Should be open again (no cooldown wait) + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(always_fail) -def test_circuit_closes_on_success(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN +# --- async_timeout tests --- - time.sleep(0.06) - cb.allow_request() # triggers HALF_OPEN - assert cb.state == CircuitState.HALF_OPEN - cb.record_success() - assert cb.state == CircuitState.CLOSED - assert cb.allow_request() is True +async def test_completes_within_timeout(): + """async_timeout doesn't interfere with fast operations.""" + async with async_timeout(1.0): + await asyncio.sleep(0.01) + result = 42 + assert result == 42 -def test_circuit_reopens_on_failure_in_half_open(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - time.sleep(0.06) - cb.allow_request() # HALF_OPEN - cb.record_failure() - assert cb.state == CircuitState.OPEN +async def test_raises_on_timeout(): + """async_timeout raises TimeoutError for slow operations.""" + with pytest.raises(TimeoutError): + async with async_timeout(0.05): + await asyncio.sleep(1.0) diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py index dc6355e..c9311dd 100644 --- a/shared/tests/test_sa_models.py +++ b/shared/tests/test_sa_models.py @@ -72,6 +72,9 @@ class TestSignalRow: "price", "quantity", "reason", + "conviction", + "stop_loss", + "take_profit", "created_at", } assert expected == cols @@ -124,44 +127,6 @@ class TestOrderRow: assert fk_cols == {"signal_id": "signals.id"} -class TestTradeRow: - def test_table_name(self): - from shared.sa_models import TradeRow - - assert TradeRow.__tablename__ == "trades" - - def test_columns(self): - from shared.sa_models import TradeRow - - mapper = inspect(TradeRow) - cols = {c.key for c in mapper.column_attrs} - expected = { - "id", - "order_id", - "symbol", - "side", - "price", - "quantity", - "fee", - "traded_at", - } - assert expected == cols - - def test_primary_key(self): - from shared.sa_models import TradeRow - - mapper = inspect(TradeRow) - pk_cols = [c.name for c in mapper.mapper.primary_key] - assert pk_cols == ["id"] - - def test_order_id_foreign_key(self): - from shared.sa_models import TradeRow - - table = TradeRow.__table__ - fk_cols = {fk.parent.name: fk.target_fullname for fk in table.foreign_keys} - assert fk_cols == {"order_id": "orders.id"} - - class TestPositionRow: def test_table_name(self): from shared.sa_models import PositionRow @@ -233,11 +198,3 @@ class TestStatusDefault: status_col = table.c.status assert status_col.server_default is not None assert status_col.server_default.arg == "PENDING" - - def test_trade_fee_server_default(self): - from shared.sa_models import TradeRow - - table = TradeRow.__table__ - fee_col = table.c.fee - assert fee_col.server_default is not None - assert fee_col.server_default.arg == "0" diff --git a/shared/tests/test_sa_news_models.py b/shared/tests/test_sa_news_models.py index 91e6d4a..dc2d026 100644 --- a/shared/tests/test_sa_news_models.py +++ b/shared/tests/test_sa_news_models.py @@ -1,6 +1,6 @@ """Tests for news-related SQLAlchemy models.""" -from shared.sa_models import NewsItemRow, SymbolScoreRow, MarketSentimentRow, StockSelectionRow +from shared.sa_models import MarketSentimentRow, NewsItemRow, StockSelectionRow, SymbolScoreRow def test_news_item_row_tablename(): diff --git a/shared/tests/test_sentiment.py b/shared/tests/test_sentiment.py deleted file mode 100644 index 9bd8ea3..0000000 --- a/shared/tests/test_sentiment.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Tests for market sentiment module.""" - -from shared.sentiment import SentimentData - - -def test_sentiment_should_buy_default_no_data(): - s = SentimentData() - assert s.should_buy is True - assert s.should_block is False - - -def test_sentiment_should_buy_low_fear_greed(): - s = SentimentData(fear_greed_value=15) - assert s.should_buy is True - - -def test_sentiment_should_not_buy_on_greed(): - s = SentimentData(fear_greed_value=75) - assert s.should_buy is False - - -def test_sentiment_should_not_buy_negative_news(): - s = SentimentData(news_sentiment=-0.4) - assert s.should_buy is False - - -def test_sentiment_should_buy_positive_news(): - s = SentimentData(fear_greed_value=50, news_sentiment=0.3) - assert s.should_buy is True - - -def test_sentiment_should_block_extreme_greed(): - s = SentimentData(fear_greed_value=85) - assert s.should_block is True - - -def test_sentiment_should_block_very_negative_news(): - s = SentimentData(news_sentiment=-0.6) - assert s.should_block is True - - -def test_sentiment_no_block_on_neutral(): - s = SentimentData(fear_greed_value=50, news_sentiment=0.0) - assert s.should_block is False diff --git a/shared/tests/test_sentiment_aggregator.py b/shared/tests/test_sentiment_aggregator.py index a99c711..9193785 100644 --- a/shared/tests/test_sentiment_aggregator.py +++ b/shared/tests/test_sentiment_aggregator.py @@ -1,7 +1,9 @@ """Tests for sentiment aggregator.""" +from datetime import UTC, datetime, timedelta + import pytest -from datetime import datetime, timezone, timedelta + from shared.sentiment import SentimentAggregator @@ -12,25 +14,25 @@ def aggregator(): def test_freshness_decay_recent(): a = SentimentAggregator() - now = datetime.now(timezone.utc) + now = datetime.now(UTC) assert a._freshness_decay(now, now) == 1.0 def test_freshness_decay_3_hours(): a = SentimentAggregator() - now = datetime.now(timezone.utc) + now = datetime.now(UTC) assert a._freshness_decay(now - timedelta(hours=3), now) == 0.7 def test_freshness_decay_12_hours(): a = SentimentAggregator() - now = datetime.now(timezone.utc) + now = datetime.now(UTC) assert a._freshness_decay(now - timedelta(hours=12), now) == 0.3 def test_freshness_decay_old(): a = SentimentAggregator() - now = datetime.now(timezone.utc) + now = datetime.now(UTC) assert a._freshness_decay(now - timedelta(days=2), now) == 0.0 @@ -44,7 +46,7 @@ def test_compute_composite(): def test_aggregate_news_by_symbol(aggregator): - now = datetime.now(timezone.utc) + now = datetime.now(UTC) news_items = [ {"symbols": ["AAPL"], "sentiment": 0.8, "category": "earnings", "published_at": now}, { @@ -64,7 +66,7 @@ def test_aggregate_news_by_symbol(aggregator): def test_aggregate_empty(aggregator): - now = datetime.now(timezone.utc) + now = datetime.now(UTC) assert aggregator.aggregate([], now) == {} diff --git a/shared/tests/test_sentiment_models.py b/shared/tests/test_sentiment_models.py index 25fc371..e00ffa6 100644 --- a/shared/tests/test_sentiment_models.py +++ b/shared/tests/test_sentiment_models.py @@ -1,16 +1,16 @@ """Tests for news and sentiment models.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from shared.models import NewsCategory, NewsItem, OrderSide -from shared.sentiment_models import SymbolScore, MarketSentiment, SelectedStock, Candidate +from shared.sentiment_models import Candidate, MarketSentiment, SelectedStock, SymbolScore def test_news_item_defaults(): item = NewsItem( source="finnhub", headline="Test headline", - published_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + published_at=datetime(2026, 4, 2, tzinfo=UTC), sentiment=0.5, category=NewsCategory.MACRO, ) @@ -25,7 +25,7 @@ def test_news_item_with_symbols(): item = NewsItem( source="rss", headline="AAPL earnings beat", - published_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + published_at=datetime(2026, 4, 2, tzinfo=UTC), sentiment=0.8, category=NewsCategory.EARNINGS, symbols=["AAPL"], @@ -52,7 +52,7 @@ def test_symbol_score(): policy_score=0.0, filing_score=0.2, composite=0.3, - updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + updated_at=datetime(2026, 4, 2, tzinfo=UTC), ) assert score.symbol == "AAPL" assert score.composite == 0.3 @@ -65,7 +65,7 @@ def test_market_sentiment(): vix=32.5, fed_stance="hawkish", market_regime="risk_off", - updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + updated_at=datetime(2026, 4, 2, tzinfo=UTC), ) assert ms.market_regime == "risk_off" assert ms.vix == 32.5 @@ -77,7 +77,7 @@ def test_market_sentiment_no_vix(): fear_greed_label="Neutral", fed_stance="neutral", market_regime="neutral", - updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + updated_at=datetime(2026, 4, 2, tzinfo=UTC), ) assert ms.vix is None diff --git a/tests/edge_cases/test_empty_data.py b/tests/edge_cases/test_empty_data.py index bfefc95..876a640 100644 --- a/tests/edge_cases/test_empty_data.py +++ b/tests/edge_cases/test_empty_data.py @@ -12,13 +12,14 @@ sys.path.insert( 0, str(Path(__file__).resolve().parents[2] / "services" / "portfolio-manager" / "src") ) -from shared.models import Signal, OrderSide from backtester.engine import BacktestEngine from backtester.metrics import compute_detailed_metrics -from portfolio_manager.portfolio import PortfolioTracker from order_executor.risk_manager import RiskManager +from portfolio_manager.portfolio import PortfolioTracker from strategies.rsi_strategy import RsiStrategy +from shared.models import OrderSide, Signal + class TestBacktestEngineEmptyCandles: """BacktestEngine.run([]) should return valid result with 0 trades.""" diff --git a/tests/edge_cases/test_extreme_values.py b/tests/edge_cases/test_extreme_values.py index b375d5e..8ec3b77 100644 --- a/tests/edge_cases/test_extreme_values.py +++ b/tests/edge_cases/test_extreme_values.py @@ -1,7 +1,7 @@ """Tests for extreme value edge cases.""" import sys -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal from pathlib import Path @@ -9,19 +9,20 @@ sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "strat sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "backtester" / "src")) sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "order-executor" / "src")) -from shared.models import Candle, Signal, OrderSide -from strategies.rsi_strategy import RsiStrategy -from strategies.vwap_strategy import VwapStrategy -from strategies.bollinger_strategy import BollingerStrategy from backtester.engine import BacktestEngine from backtester.simulator import OrderSimulator from order_executor.risk_manager import RiskManager +from strategies.bollinger_strategy import BollingerStrategy +from strategies.rsi_strategy import RsiStrategy +from strategies.vwap_strategy import VwapStrategy + +from shared.models import Candle, OrderSide, Signal def _candle(close: str, volume: str = "1000", idx: int = 0) -> Candle: from datetime import timedelta - base = datetime(2025, 1, 1, tzinfo=timezone.utc) + base = datetime(2025, 1, 1, tzinfo=UTC) return Candle( symbol="AAPL", timeframe="1h", diff --git a/tests/edge_cases/test_strategy_reset.py b/tests/edge_cases/test_strategy_reset.py index 6e9b956..13ed4da 100644 --- a/tests/edge_cases/test_strategy_reset.py +++ b/tests/edge_cases/test_strategy_reset.py @@ -1,21 +1,22 @@ """Tests that strategy reset() properly clears internal state.""" import sys -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "strategy-engine")) sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "backtester" / "src")) -from shared.models import Candle -from strategies.rsi_strategy import RsiStrategy -from strategies.grid_strategy import GridStrategy -from strategies.macd_strategy import MacdStrategy from strategies.bollinger_strategy import BollingerStrategy from strategies.ema_crossover_strategy import EmaCrossoverStrategy -from strategies.vwap_strategy import VwapStrategy +from strategies.grid_strategy import GridStrategy +from strategies.macd_strategy import MacdStrategy +from strategies.rsi_strategy import RsiStrategy from strategies.volume_profile_strategy import VolumeProfileStrategy +from strategies.vwap_strategy import VwapStrategy + +from shared.models import Candle def _make_candles(count: int, base_price: float = 100.0) -> list[Candle]: @@ -28,7 +29,7 @@ def _make_candles(count: int, base_price: float = 100.0) -> list[Candle]: Candle( symbol="AAPL", timeframe="1h", - open_time=datetime(2025, 1, 1, i % 24, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, i % 24, tzinfo=UTC), open=Decimal(str(price)), high=Decimal(str(price + 1)), low=Decimal(str(price - 1)), @@ -57,7 +58,7 @@ class TestRsiReset: strategy.reset() signals2 = _collect_signals(strategy, candles) assert len(signals1) == len(signals2) - for s1, s2 in zip(signals1, signals2): + for s1, s2 in zip(signals1, signals2, strict=True): assert s1.side == s2.side assert s1.price == s2.price @@ -71,7 +72,7 @@ class TestGridReset: strategy.reset() signals2 = _collect_signals(strategy, candles) assert len(signals1) == len(signals2) - for s1, s2 in zip(signals1, signals2): + for s1, s2 in zip(signals1, signals2, strict=True): assert s1.side == s2.side assert s1.price == s2.price @@ -84,7 +85,7 @@ class TestMacdReset: strategy.reset() signals2 = _collect_signals(strategy, candles) assert len(signals1) == len(signals2) - for s1, s2 in zip(signals1, signals2): + for s1, s2 in zip(signals1, signals2, strict=True): assert s1.side == s2.side assert s1.price == s2.price @@ -97,7 +98,7 @@ class TestBollingerReset: strategy.reset() signals2 = _collect_signals(strategy, candles) assert len(signals1) == len(signals2) - for s1, s2 in zip(signals1, signals2): + for s1, s2 in zip(signals1, signals2, strict=True): assert s1.side == s2.side assert s1.price == s2.price @@ -110,7 +111,7 @@ class TestEmaCrossoverReset: strategy.reset() signals2 = _collect_signals(strategy, candles) assert len(signals1) == len(signals2) - for s1, s2 in zip(signals1, signals2): + for s1, s2 in zip(signals1, signals2, strict=True): assert s1.side == s2.side assert s1.price == s2.price @@ -123,7 +124,7 @@ class TestVwapReset: strategy.reset() signals2 = _collect_signals(strategy, candles) assert len(signals1) == len(signals2) - for s1, s2 in zip(signals1, signals2): + for s1, s2 in zip(signals1, signals2, strict=True): assert s1.side == s2.side assert s1.price == s2.price @@ -136,6 +137,6 @@ class TestVolumeProfileReset: strategy.reset() signals2 = _collect_signals(strategy, candles) assert len(signals1) == len(signals2) - for s1, s2 in zip(signals1, signals2): + for s1, s2 in zip(signals1, signals2, strict=True): assert s1.side == s2.side assert s1.price == s2.price diff --git a/tests/edge_cases/test_zero_volume.py b/tests/edge_cases/test_zero_volume.py index ba2c133..df247cc 100644 --- a/tests/edge_cases/test_zero_volume.py +++ b/tests/edge_cases/test_zero_volume.py @@ -1,21 +1,22 @@ """Tests for strategies handling zero-volume candles gracefully.""" import sys -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "strategy-engine")) sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "backtester" / "src")) -from shared.models import Candle -from strategies.vwap_strategy import VwapStrategy -from strategies.volume_profile_strategy import VolumeProfileStrategy from strategies.rsi_strategy import RsiStrategy +from strategies.volume_profile_strategy import VolumeProfileStrategy +from strategies.vwap_strategy import VwapStrategy + +from shared.models import Candle def _candle(close: str, volume: str = "0", idx: int = 0) -> Candle: - base = datetime(2025, 1, 1, tzinfo=timezone.utc) + base = datetime(2025, 1, 1, tzinfo=UTC) from datetime import timedelta return Candle( diff --git a/tests/integration/test_backtest_end_to_end.py b/tests/integration/test_backtest_end_to_end.py index 4cc0b12..fbc0a24 100644 --- a/tests/integration/test_backtest_end_to_end.py +++ b/tests/integration/test_backtest_end_to_end.py @@ -9,19 +9,20 @@ sys.path.insert( sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "strategy-engine")) sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "backtester" / "src")) +from datetime import UTC, datetime, timedelta from decimal import Decimal -from datetime import datetime, timedelta, timezone -from shared.models import Candle from backtester.engine import BacktestEngine +from shared.models import Candle + def _generate_candles(prices: list[float], symbol="AAPL") -> list[Candle]: return [ Candle( symbol=symbol, timeframe="1h", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc) + timedelta(hours=i), + open_time=datetime(2025, 1, 1, tzinfo=UTC) + timedelta(hours=i), open=Decimal(str(p)), high=Decimal(str(p + 100)), low=Decimal(str(p - 100)), diff --git a/tests/integration/test_order_execution_flow.py b/tests/integration/test_order_execution_flow.py index dcbc498..2beb388 100644 --- a/tests/integration/test_order_execution_flow.py +++ b/tests/integration/test_order_execution_flow.py @@ -5,14 +5,15 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "order-executor" / "src")) -import pytest from decimal import Decimal from unittest.mock import AsyncMock -from shared.models import Signal, OrderSide, OrderStatus +import pytest from order_executor.executor import OrderExecutor from order_executor.risk_manager import RiskManager +from shared.models import OrderSide, OrderStatus, Signal + @pytest.mark.asyncio async def test_signal_to_order_flow(): diff --git a/tests/integration/test_portfolio_tracking_flow.py b/tests/integration/test_portfolio_tracking_flow.py index b20275a..d91a265 100644 --- a/tests/integration/test_portfolio_tracking_flow.py +++ b/tests/integration/test_portfolio_tracking_flow.py @@ -9,9 +9,10 @@ sys.path.insert( from decimal import Decimal -from shared.models import Order, OrderSide, OrderType, OrderStatus from portfolio_manager.portfolio import PortfolioTracker +from shared.models import Order, OrderSide, OrderStatus, OrderType + def test_portfolio_tracks_buy_sell_cycle(): """Buy then sell should update position and reset on full sell.""" diff --git a/tests/integration/test_strategy_signal_flow.py b/tests/integration/test_strategy_signal_flow.py index 6b048fb..3f7ec35 100644 --- a/tests/integration/test_strategy_signal_flow.py +++ b/tests/integration/test_strategy_signal_flow.py @@ -8,15 +8,16 @@ sys.path.insert( ) sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "strategy-engine")) -import pytest +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone from unittest.mock import AsyncMock -from shared.models import Candle -from shared.events import CandleEvent +import pytest from strategy_engine.engine import StrategyEngine +from shared.events import CandleEvent +from shared.models import Candle + @pytest.fixture def candles(): @@ -28,7 +29,7 @@ def candles(): Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2025, 1, 1, i, 0, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, i, 0, tzinfo=UTC), open=price, high=price + 1, low=price - 1, |
