summaryrefslogtreecommitdiff
path: root/services
diff options
context:
space:
mode:
Diffstat (limited to 'services')
-rw-r--r--services/api/Dockerfile15
-rw-r--r--services/api/pyproject.toml6
-rw-r--r--services/api/src/trading_api/dependencies/__init__.py0
-rw-r--r--services/api/src/trading_api/dependencies/auth.py29
-rw-r--r--services/api/src/trading_api/main.py50
-rw-r--r--services/api/src/trading_api/routers/orders.py29
-rw-r--r--services/api/src/trading_api/routers/portfolio.py22
-rw-r--r--services/api/src/trading_api/routers/strategies.py7
-rw-r--r--services/api/tests/test_api.py1
-rw-r--r--services/api/tests/test_orders_router.py6
-rw-r--r--services/api/tests/test_portfolio_router.py10
-rw-r--r--services/backtester/Dockerfile9
-rw-r--r--services/backtester/pyproject.toml2
-rw-r--r--services/backtester/src/backtester/config.py2
-rw-r--r--services/backtester/src/backtester/engine.py5
-rw-r--r--services/backtester/src/backtester/main.py6
-rw-r--r--services/backtester/src/backtester/metrics.py2
-rw-r--r--services/backtester/src/backtester/simulator.py19
-rw-r--r--services/backtester/src/backtester/walk_forward.py4
-rw-r--r--services/backtester/tests/test_engine.py9
-rw-r--r--services/backtester/tests/test_metrics.py9
-rw-r--r--services/backtester/tests/test_simulator.py13
-rw-r--r--services/backtester/tests/test_walk_forward.py12
-rw-r--r--services/data-collector/Dockerfile9
-rw-r--r--services/data-collector/src/data_collector/main.py31
-rw-r--r--services/data-collector/tests/test_storage.py15
-rw-r--r--services/news-collector/Dockerfile17
-rw-r--r--services/news-collector/pyproject.toml20
-rw-r--r--services/news-collector/src/news_collector/__init__.py1
-rw-r--r--services/news-collector/src/news_collector/collectors/__init__.py1
-rw-r--r--services/news-collector/src/news_collector/collectors/base.py18
-rw-r--r--services/news-collector/src/news_collector/collectors/fear_greed.py62
-rw-r--r--services/news-collector/src/news_collector/collectors/fed.py119
-rw-r--r--services/news-collector/src/news_collector/collectors/finnhub.py88
-rw-r--r--services/news-collector/src/news_collector/collectors/reddit.py97
-rw-r--r--services/news-collector/src/news_collector/collectors/rss.py105
-rw-r--r--services/news-collector/src/news_collector/collectors/sec_edgar.py98
-rw-r--r--services/news-collector/src/news_collector/collectors/truth_social.py86
-rw-r--r--services/news-collector/src/news_collector/config.py7
-rw-r--r--services/news-collector/src/news_collector/main.py204
-rw-r--r--services/news-collector/tests/__init__.py0
-rw-r--r--services/news-collector/tests/test_fear_greed.py49
-rw-r--r--services/news-collector/tests/test_fed.py38
-rw-r--r--services/news-collector/tests/test_finnhub.py67
-rw-r--r--services/news-collector/tests/test_main.py41
-rw-r--r--services/news-collector/tests/test_reddit.py64
-rw-r--r--services/news-collector/tests/test_rss.py47
-rw-r--r--services/news-collector/tests/test_sec_edgar.py58
-rw-r--r--services/news-collector/tests/test_truth_social.py42
-rw-r--r--services/order-executor/Dockerfile9
-rw-r--r--services/order-executor/src/order_executor/executor.py16
-rw-r--r--services/order-executor/src/order_executor/main.py45
-rw-r--r--services/order-executor/src/order_executor/risk_manager.py26
-rw-r--r--services/order-executor/tests/test_executor.py4
-rw-r--r--services/order-executor/tests/test_risk_manager.py48
-rw-r--r--services/portfolio-manager/Dockerfile9
-rw-r--r--services/portfolio-manager/src/portfolio_manager/main.py43
-rw-r--r--services/portfolio-manager/tests/test_portfolio.py27
-rw-r--r--services/portfolio-manager/tests/test_snapshot.py5
-rw-r--r--services/strategy-engine/Dockerfile9
-rw-r--r--services/strategy-engine/pyproject.toml6
-rw-r--r--services/strategy-engine/src/strategy_engine/config.py2
-rw-r--r--services/strategy-engine/src/strategy_engine/engine.py8
-rw-r--r--services/strategy-engine/src/strategy_engine/main.py85
-rw-r--r--services/strategy-engine/src/strategy_engine/plugin_loader.py1
-rw-r--r--services/strategy-engine/src/strategy_engine/stock_selector.py418
-rw-r--r--services/strategy-engine/strategies/base.py5
-rw-r--r--services/strategy-engine/strategies/bollinger_strategy.py2
-rw-r--r--services/strategy-engine/strategies/combined_strategy.py2
-rw-r--r--services/strategy-engine/strategies/ema_crossover_strategy.py2
-rw-r--r--services/strategy-engine/strategies/grid_strategy.py5
-rw-r--r--services/strategy-engine/strategies/indicators/__init__.py16
-rw-r--r--services/strategy-engine/strategies/indicators/momentum.py2
-rw-r--r--services/strategy-engine/strategies/indicators/trend.py2
-rw-r--r--services/strategy-engine/strategies/indicators/volatility.py2
-rw-r--r--services/strategy-engine/strategies/indicators/volume.py2
-rw-r--r--services/strategy-engine/strategies/macd_strategy.py2
-rw-r--r--services/strategy-engine/strategies/moc_strategy.py4
-rw-r--r--services/strategy-engine/strategies/rsi_strategy.py2
-rw-r--r--services/strategy-engine/strategies/volume_profile_strategy.py4
-rw-r--r--services/strategy-engine/strategies/vwap_strategy.py4
-rw-r--r--services/strategy-engine/tests/conftest.py5
-rw-r--r--services/strategy-engine/tests/test_base_filters.py7
-rw-r--r--services/strategy-engine/tests/test_bollinger_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_combined_strategy.py11
-rw-r--r--services/strategy-engine/tests/test_ema_crossover_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_engine.py8
-rw-r--r--services/strategy-engine/tests/test_grid_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_indicators.py9
-rw-r--r--services/strategy-engine/tests/test_macd_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_moc_strategy.py7
-rw-r--r--services/strategy-engine/tests/test_multi_symbol.py10
-rw-r--r--services/strategy-engine/tests/test_plugin_loader.py2
-rw-r--r--services/strategy-engine/tests/test_rsi_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_stock_selector.py111
-rw-r--r--services/strategy-engine/tests/test_strategy_validation.py8
-rw-r--r--services/strategy-engine/tests/test_volume_profile_strategy.py15
-rw-r--r--services/strategy-engine/tests/test_vwap_strategy.py12
98 files changed, 2406 insertions, 297 deletions
diff --git a/services/api/Dockerfile b/services/api/Dockerfile
index b942075..93d2b75 100644
--- a/services/api/Dockerfile
+++ b/services/api/Dockerfile
@@ -1,11 +1,18 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/api/ services/api/
RUN pip install --no-cache-dir ./services/api
-COPY services/strategy-engine/strategies/ /app/strategies/
COPY services/strategy-engine/ services/strategy-engine/
RUN pip install --no-cache-dir ./services/strategy-engine
-ENV PYTHONPATH=/app
-CMD ["uvicorn", "trading_api.main:app", "--host", "0.0.0.0", "--port", "8000"]
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
+COPY services/strategy-engine/strategies/ /app/strategies/
+ENV PYTHONPATH=/app STRATEGIES_DIR=/app/strategies
+USER appuser
+CMD ["uvicorn", "trading_api.main:app", "--host", "0.0.0.0", "--port", "8000", "--timeout-graceful-shutdown", "30"]
diff --git a/services/api/pyproject.toml b/services/api/pyproject.toml
index fd2598d..95099d2 100644
--- a/services/api/pyproject.toml
+++ b/services/api/pyproject.toml
@@ -3,11 +3,7 @@ name = "trading-api"
version = "0.1.0"
description = "REST API for the trading platform"
requires-python = ">=3.12"
-dependencies = [
- "fastapi>=0.110",
- "uvicorn>=0.27",
- "trading-shared",
-]
+dependencies = ["fastapi>=0.110,<1", "uvicorn>=0.27,<1", "slowapi>=0.1.9,<1", "trading-shared"]
[project.optional-dependencies]
dev = ["pytest>=8.0", "pytest-asyncio>=0.23", "httpx>=0.27"]
diff --git a/services/api/src/trading_api/dependencies/__init__.py b/services/api/src/trading_api/dependencies/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/services/api/src/trading_api/dependencies/__init__.py
diff --git a/services/api/src/trading_api/dependencies/auth.py b/services/api/src/trading_api/dependencies/auth.py
new file mode 100644
index 0000000..a5e76c1
--- /dev/null
+++ b/services/api/src/trading_api/dependencies/auth.py
@@ -0,0 +1,29 @@
+"""Bearer token authentication dependency."""
+
+import logging
+
+from fastapi import Depends, HTTPException, status
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+
+from shared.config import Settings
+
+logger = logging.getLogger(__name__)
+
+_security = HTTPBearer(auto_error=False)
+_settings = Settings()
+
+
+async def verify_token(
+ credentials: HTTPAuthorizationCredentials | None = Depends(_security),
+) -> None:
+ """Verify Bearer token. Skip auth if API_AUTH_TOKEN is not configured."""
+ token = _settings.api_auth_token.get_secret_value()
+ if not token:
+ return # Auth disabled in dev mode
+
+ if credentials is None or credentials.credentials != token:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid or missing authentication token",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
diff --git a/services/api/src/trading_api/main.py b/services/api/src/trading_api/main.py
index 39f7b43..05c6d2f 100644
--- a/services/api/src/trading_api/main.py
+++ b/services/api/src/trading_api/main.py
@@ -1,33 +1,71 @@
"""Trading Platform REST API."""
+import logging
from contextlib import asynccontextmanager
-from fastapi import FastAPI
+from fastapi import Depends, FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from slowapi import Limiter, _rate_limit_exceeded_handler
+from slowapi.errors import RateLimitExceeded
+from slowapi.util import get_remote_address
from shared.config import Settings
from shared.db import Database
+from trading_api.dependencies.auth import verify_token
+from trading_api.routers import orders, portfolio, strategies
-from trading_api.routers import portfolio, orders, strategies
+logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
settings = Settings()
- app.state.db = Database(settings.database_url)
+ if not settings.api_auth_token.get_secret_value():
+ logger.warning("API_AUTH_TOKEN not set — authentication is disabled")
+ app.state.db = Database(settings.database_url.get_secret_value())
await app.state.db.connect()
yield
await app.state.db.close()
+cfg = Settings()
+
+limiter = Limiter(key_func=get_remote_address)
+
app = FastAPI(
title="Trading Platform API",
version="0.1.0",
lifespan=lifespan,
)
-app.include_router(portfolio.router, prefix="/api/v1/portfolio", tags=["portfolio"])
-app.include_router(orders.router, prefix="/api/v1/orders", tags=["orders"])
-app.include_router(strategies.router, prefix="/api/v1/strategies", tags=["strategies"])
+app.state.limiter = limiter
+app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=cfg.cors_origins.split(","),
+ allow_methods=["GET", "POST"],
+ allow_headers=["Authorization", "Content-Type"],
+)
+
+app.include_router(
+ portfolio.router,
+ prefix="/api/v1/portfolio",
+ tags=["portfolio"],
+ dependencies=[Depends(verify_token)],
+)
+app.include_router(
+ orders.router,
+ prefix="/api/v1/orders",
+ tags=["orders"],
+ dependencies=[Depends(verify_token)],
+)
+app.include_router(
+ strategies.router,
+ prefix="/api/v1/strategies",
+ tags=["strategies"],
+ dependencies=[Depends(verify_token)],
+)
@app.get("/health")
diff --git a/services/api/src/trading_api/routers/orders.py b/services/api/src/trading_api/routers/orders.py
index c69dc10..b664e2a 100644
--- a/services/api/src/trading_api/routers/orders.py
+++ b/services/api/src/trading_api/routers/orders.py
@@ -2,17 +2,23 @@
import logging
-from fastapi import APIRouter, HTTPException, Request
-from shared.sa_models import OrderRow, SignalRow
+from fastapi import APIRouter, HTTPException, Query, Request
+from slowapi import Limiter
+from slowapi.util import get_remote_address
from sqlalchemy import select
+from sqlalchemy.exc import OperationalError
+
+from shared.sa_models import OrderRow, SignalRow
logger = logging.getLogger(__name__)
router = APIRouter()
+limiter = Limiter(key_func=get_remote_address)
@router.get("/")
-async def get_orders(request: Request, limit: int = 50):
+@limiter.limit("60/minute")
+async def get_orders(request: Request, limit: int = Query(50, ge=1, le=1000)):
"""Get recent orders."""
try:
db = request.app.state.db
@@ -35,13 +41,17 @@ async def get_orders(request: Request, limit: int = 50):
}
for r in rows
]
+ except OperationalError as exc:
+ logger.error("Database error fetching orders: %s", exc)
+ raise HTTPException(status_code=503, detail="Database unavailable") from exc
except Exception as exc:
- logger.error("Failed to get orders: %s", exc)
- raise HTTPException(status_code=500, detail="Failed to retrieve orders")
+ logger.error("Failed to get orders: %s", exc, exc_info=True)
+ raise HTTPException(status_code=500, detail="Failed to retrieve orders") from exc
@router.get("/signals")
-async def get_signals(request: Request, limit: int = 50):
+@limiter.limit("60/minute")
+async def get_signals(request: Request, limit: int = Query(50, ge=1, le=1000)):
"""Get recent signals."""
try:
db = request.app.state.db
@@ -62,6 +72,9 @@ async def get_signals(request: Request, limit: int = 50):
}
for r in rows
]
+ except OperationalError as exc:
+ logger.error("Database error fetching signals: %s", exc)
+ raise HTTPException(status_code=503, detail="Database unavailable") from exc
except Exception as exc:
- logger.error("Failed to get signals: %s", exc)
- raise HTTPException(status_code=500, detail="Failed to retrieve signals")
+ logger.error("Failed to get signals: %s", exc, exc_info=True)
+ raise HTTPException(status_code=500, detail="Failed to retrieve signals") from exc
diff --git a/services/api/src/trading_api/routers/portfolio.py b/services/api/src/trading_api/routers/portfolio.py
index d76d85d..56bee7c 100644
--- a/services/api/src/trading_api/routers/portfolio.py
+++ b/services/api/src/trading_api/routers/portfolio.py
@@ -2,9 +2,11 @@
import logging
-from fastapi import APIRouter, HTTPException, Request
-from shared.sa_models import PositionRow
+from fastapi import APIRouter, HTTPException, Query, Request
from sqlalchemy import select
+from sqlalchemy.exc import OperationalError
+
+from shared.sa_models import PositionRow
logger = logging.getLogger(__name__)
@@ -29,13 +31,16 @@ async def get_positions(request: Request):
}
for r in rows
]
+ except OperationalError as exc:
+ logger.error("Database error fetching positions: %s", exc)
+ raise HTTPException(status_code=503, detail="Database unavailable") from exc
except Exception as exc:
- logger.error("Failed to get positions: %s", exc)
- raise HTTPException(status_code=500, detail="Failed to retrieve positions")
+ logger.error("Failed to get positions: %s", exc, exc_info=True)
+ raise HTTPException(status_code=500, detail="Failed to retrieve positions") from exc
@router.get("/snapshots")
-async def get_snapshots(request: Request, days: int = 30):
+async def get_snapshots(request: Request, days: int = Query(30, ge=1, le=365)):
"""Get portfolio snapshots for the last N days."""
try:
db = request.app.state.db
@@ -49,6 +54,9 @@ async def get_snapshots(request: Request, days: int = 30):
}
for s in snapshots
]
+ except OperationalError as exc:
+ logger.error("Database error fetching snapshots: %s", exc)
+ raise HTTPException(status_code=503, detail="Database unavailable") from exc
except Exception as exc:
- logger.error("Failed to get snapshots: %s", exc)
- raise HTTPException(status_code=500, detail="Failed to retrieve snapshots")
+ logger.error("Failed to get snapshots: %s", exc, exc_info=True)
+ raise HTTPException(status_code=500, detail="Failed to retrieve snapshots") from exc
diff --git a/services/api/src/trading_api/routers/strategies.py b/services/api/src/trading_api/routers/strategies.py
index 7ddd54e..157094c 100644
--- a/services/api/src/trading_api/routers/strategies.py
+++ b/services/api/src/trading_api/routers/strategies.py
@@ -42,6 +42,9 @@ async def list_strategies():
}
for s in strategies
]
+ except (ImportError, FileNotFoundError) as exc:
+ logger.error("Strategy loading error: %s", exc)
+ raise HTTPException(status_code=503, detail="Strategy engine unavailable") from exc
except Exception as exc:
- logger.error("Failed to list strategies: %s", exc)
- raise HTTPException(status_code=500, detail="Failed to list strategies")
+ logger.error("Failed to list strategies: %s", exc, exc_info=True)
+ raise HTTPException(status_code=500, detail="Failed to list strategies") from exc
diff --git a/services/api/tests/test_api.py b/services/api/tests/test_api.py
index 669143b..f3b0a47 100644
--- a/services/api/tests/test_api.py
+++ b/services/api/tests/test_api.py
@@ -1,6 +1,7 @@
"""Tests for the REST API."""
from unittest.mock import AsyncMock, patch
+
from fastapi.testclient import TestClient
diff --git a/services/api/tests/test_orders_router.py b/services/api/tests/test_orders_router.py
index 0658619..52252c5 100644
--- a/services/api/tests/test_orders_router.py
+++ b/services/api/tests/test_orders_router.py
@@ -1,10 +1,10 @@
"""Tests for orders API router."""
-import pytest
from unittest.mock import AsyncMock, MagicMock
-from fastapi.testclient import TestClient
-from fastapi import FastAPI
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
from trading_api.routers.orders import router
diff --git a/services/api/tests/test_portfolio_router.py b/services/api/tests/test_portfolio_router.py
index f2584ea..8cd8ff8 100644
--- a/services/api/tests/test_portfolio_router.py
+++ b/services/api/tests/test_portfolio_router.py
@@ -1,11 +1,11 @@
"""Tests for portfolio API router."""
-import pytest
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock
-from fastapi.testclient import TestClient
-from fastapi import FastAPI
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
from trading_api.routers.portfolio import router
@@ -45,7 +45,7 @@ def test_get_positions_with_data(app, mock_db):
app.state.db = db
mock_row = MagicMock()
- mock_row.symbol = "BTCUSDT"
+ mock_row.symbol = "AAPL"
mock_row.quantity = Decimal("0.1")
mock_row.avg_entry_price = Decimal("50000")
mock_row.current_price = Decimal("55000")
@@ -59,7 +59,7 @@ def test_get_positions_with_data(app, mock_db):
assert response.status_code == 200
data = response.json()
assert len(data) == 1
- assert data[0]["symbol"] == "BTCUSDT"
+ assert data[0]["symbol"] == "AAPL"
def test_get_snapshots_empty(app, mock_db):
diff --git a/services/backtester/Dockerfile b/services/backtester/Dockerfile
index 9a4f439..1108e42 100644
--- a/services/backtester/Dockerfile
+++ b/services/backtester/Dockerfile
@@ -1,10 +1,17 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/backtester/ services/backtester/
RUN pip install --no-cache-dir ./services/backtester
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
COPY services/strategy-engine/strategies/ /app/strategies/
ENV STRATEGIES_DIR=/app/strategies
ENV PYTHONPATH=/app
+USER appuser
CMD ["python", "-m", "backtester.main"]
diff --git a/services/backtester/pyproject.toml b/services/backtester/pyproject.toml
index 2601d04..034bcf6 100644
--- a/services/backtester/pyproject.toml
+++ b/services/backtester/pyproject.toml
@@ -3,7 +3,7 @@ name = "backtester"
version = "0.1.0"
description = "Strategy backtesting engine"
requires-python = ">=3.12"
-dependencies = ["pandas>=2.0", "numpy>=1.20", "rich>=13.0", "trading-shared"]
+dependencies = ["pandas>=2.1,<3", "numpy>=1.26,<3", "rich>=13.0,<14", "trading-shared"]
[project.optional-dependencies]
dev = ["pytest>=8.0", "pytest-asyncio>=0.23"]
diff --git a/services/backtester/src/backtester/config.py b/services/backtester/src/backtester/config.py
index f7897da..57ee1fb 100644
--- a/services/backtester/src/backtester/config.py
+++ b/services/backtester/src/backtester/config.py
@@ -5,7 +5,7 @@ from shared.config import Settings
class BacktestConfig(Settings):
backtest_initial_balance: float = 10000.0
- symbol: str = "BTCUSDT"
+ symbol: str = "AAPL"
timeframe: str = "1h"
strategy_name: str = "rsi_strategy"
candle_limit: int = 500
diff --git a/services/backtester/src/backtester/engine.py b/services/backtester/src/backtester/engine.py
index b03715d..fcf48f1 100644
--- a/services/backtester/src/backtester/engine.py
+++ b/services/backtester/src/backtester/engine.py
@@ -6,10 +6,9 @@ from dataclasses import dataclass, field
from decimal import Decimal
from typing import Protocol
-from shared.models import Candle, Signal
-
from backtester.metrics import DetailedMetrics, TradeRecord, compute_detailed_metrics
from backtester.simulator import OrderSimulator, SimulatedTrade
+from shared.models import Candle, Signal
class StrategyProtocol(Protocol):
@@ -101,7 +100,7 @@ class BacktestEngine:
final_balance = simulator.balance
if candles:
last_price = candles[-1].close
- for symbol, qty in simulator.positions.items():
+ for qty in simulator.positions.values():
if qty > Decimal("0"):
final_balance += qty * last_price
elif qty < Decimal("0"):
diff --git a/services/backtester/src/backtester/main.py b/services/backtester/src/backtester/main.py
index a4cea76..dbde00b 100644
--- a/services/backtester/src/backtester/main.py
+++ b/services/backtester/src/backtester/main.py
@@ -17,11 +17,11 @@ _STRATEGIES_DIR = Path(
if _STRATEGIES_DIR.parent not in [Path(p) for p in sys.path]:
sys.path.insert(0, str(_STRATEGIES_DIR.parent))
-from shared.db import Database # noqa: E402
-from shared.models import Candle # noqa: E402
from backtester.config import BacktestConfig # noqa: E402
from backtester.engine import BacktestEngine # noqa: E402
from backtester.reporter import format_report # noqa: E402
+from shared.db import Database # noqa: E402
+from shared.models import Candle # noqa: E402
async def run_backtest() -> str:
@@ -45,7 +45,7 @@ async def run_backtest() -> str:
except Exception as exc:
raise RuntimeError(f"Failed to load strategy '{config.strategy_name}': {exc}") from exc
- db = Database(config.database_url)
+ db = Database(config.database_url.get_secret_value())
await db.connect()
try:
rows = await db.get_candles(config.symbol, config.timeframe, config.candle_limit)
diff --git a/services/backtester/src/backtester/metrics.py b/services/backtester/src/backtester/metrics.py
index 239cb6f..c7b032b 100644
--- a/services/backtester/src/backtester/metrics.py
+++ b/services/backtester/src/backtester/metrics.py
@@ -266,7 +266,7 @@ def compute_detailed_metrics(
largest_win=largest_win,
largest_loss=largest_loss,
avg_holding_period=avg_holding,
- trade_pairs=[p for p in pairs],
+ trade_pairs=list(pairs),
risk_free_rate=risk_free_rate,
recovery_factor=recovery_factor,
max_consecutive_losses=max_consec_losses,
diff --git a/services/backtester/src/backtester/simulator.py b/services/backtester/src/backtester/simulator.py
index 64c88dd..6bce18b 100644
--- a/services/backtester/src/backtester/simulator.py
+++ b/services/backtester/src/backtester/simulator.py
@@ -1,9 +1,8 @@
"""Simulated order executor for backtesting."""
from dataclasses import dataclass, field
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
-from typing import Optional
from shared.models import OrderSide, Signal
@@ -16,7 +15,7 @@ class SimulatedTrade:
quantity: Decimal
balance_after: Decimal
fee: Decimal = Decimal("0")
- timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+ timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
@dataclass
@@ -27,8 +26,8 @@ class OpenPosition:
side: OrderSide # BUY = long, SELL = short
entry_price: Decimal
quantity: Decimal
- stop_loss: Optional[Decimal] = None
- take_profit: Optional[Decimal] = None
+ stop_loss: Decimal | None = None
+ take_profit: Decimal | None = None
class OrderSimulator:
@@ -70,7 +69,7 @@ class OrderSimulator:
remaining: list[OpenPosition] = []
for pos in self.open_positions:
triggered = False
- exit_price: Optional[Decimal] = None
+ exit_price: Decimal | None = None
if pos.side == OrderSide.BUY: # Long position
if pos.stop_loss is not None and candle_low <= pos.stop_loss:
@@ -125,12 +124,12 @@ class OrderSimulator:
def execute(
self,
signal: Signal,
- timestamp: Optional[datetime] = None,
- stop_loss: Optional[Decimal] = None,
- take_profit: Optional[Decimal] = None,
+ timestamp: datetime | None = None,
+ stop_loss: Decimal | None = None,
+ take_profit: Decimal | None = None,
) -> bool:
"""Execute a signal with slippage and fees. Returns True if accepted."""
- ts = timestamp or datetime.now(timezone.utc)
+ ts = timestamp or datetime.now(UTC)
exec_price = self._apply_slippage(signal.price, signal.side)
fee = self._calculate_fee(exec_price, signal.quantity)
diff --git a/services/backtester/src/backtester/walk_forward.py b/services/backtester/src/backtester/walk_forward.py
index c7b7fd8..720ad5e 100644
--- a/services/backtester/src/backtester/walk_forward.py
+++ b/services/backtester/src/backtester/walk_forward.py
@@ -1,11 +1,11 @@
"""Walk-forward analysis for strategy parameter optimization."""
+from collections.abc import Callable
from dataclasses import dataclass, field
from decimal import Decimal
-from typing import Callable
-from shared.models import Candle
from backtester.engine import BacktestEngine, BacktestResult, StrategyProtocol
+from shared.models import Candle
@dataclass
diff --git a/services/backtester/tests/test_engine.py b/services/backtester/tests/test_engine.py
index 4794e63..f789831 100644
--- a/services/backtester/tests/test_engine.py
+++ b/services/backtester/tests/test_engine.py
@@ -1,20 +1,19 @@
"""Tests for the BacktestEngine."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
from unittest.mock import MagicMock
-
-from shared.models import Candle, Signal, OrderSide
-
from backtester.engine import BacktestEngine
+from shared.models import Candle, OrderSide, Signal
+
def make_candle(symbol: str, price: float, timeframe: str = "1h") -> Candle:
return Candle(
symbol=symbol,
timeframe=timeframe,
- open_time=datetime.now(timezone.utc),
+ open_time=datetime.now(UTC),
open=Decimal(str(price)),
high=Decimal(str(price * 1.01)),
low=Decimal(str(price * 0.99)),
diff --git a/services/backtester/tests/test_metrics.py b/services/backtester/tests/test_metrics.py
index 55f5b6c..13e545e 100644
--- a/services/backtester/tests/test_metrics.py
+++ b/services/backtester/tests/test_metrics.py
@@ -1,17 +1,16 @@
"""Tests for detailed backtest metrics."""
import math
-from datetime import datetime, timedelta, timezone
+from datetime import UTC, datetime, timedelta
from decimal import Decimal
import pytest
-
from backtester.metrics import TradeRecord, compute_detailed_metrics
def _make_trade(side: str, price: str, minutes_offset: int = 0) -> TradeRecord:
return TradeRecord(
- time=datetime(2025, 1, 1, tzinfo=timezone.utc) + timedelta(minutes=minutes_offset),
+ time=datetime(2025, 1, 1, tzinfo=UTC) + timedelta(minutes=minutes_offset),
symbol="AAPL",
side=side,
price=Decimal(price),
@@ -124,7 +123,7 @@ def test_consecutive_losses():
def test_risk_free_rate_affects_sharpe():
"""Higher risk-free rate should lower Sharpe ratio."""
- base = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ base = datetime(2025, 1, 1, tzinfo=UTC)
trades = [
TradeRecord(
time=base, symbol="AAPL", side="BUY", price=Decimal("100"), quantity=Decimal("1")
@@ -184,7 +183,7 @@ def test_daily_returns_populated():
def test_fee_subtracted_from_pnl():
"""Fees should be subtracted from trade PnL."""
- base = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ base = datetime(2025, 1, 1, tzinfo=UTC)
trades_with_fees = [
TradeRecord(
time=base,
diff --git a/services/backtester/tests/test_simulator.py b/services/backtester/tests/test_simulator.py
index 62e2cdb..f85594f 100644
--- a/services/backtester/tests/test_simulator.py
+++ b/services/backtester/tests/test_simulator.py
@@ -1,11 +1,12 @@
"""Tests for the OrderSimulator."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
-from shared.models import OrderSide, Signal
from backtester.simulator import OrderSimulator
+from shared.models import OrderSide, Signal
+
def make_signal(
symbol: str,
@@ -135,7 +136,7 @@ def test_stop_loss_triggers():
signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1")
sim.execute(signal, stop_loss=Decimal("48000"))
- ts = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ ts = datetime(2025, 1, 1, tzinfo=UTC)
closed = sim.check_stops(
candle_high=Decimal("50500"),
candle_low=Decimal("47500"), # below stop_loss
@@ -153,7 +154,7 @@ def test_take_profit_triggers():
signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1")
sim.execute(signal, take_profit=Decimal("55000"))
- ts = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ ts = datetime(2025, 1, 1, tzinfo=UTC)
closed = sim.check_stops(
candle_high=Decimal("56000"), # above take_profit
candle_low=Decimal("50000"),
@@ -171,7 +172,7 @@ def test_stop_not_triggered_within_range():
signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1")
sim.execute(signal, stop_loss=Decimal("48000"), take_profit=Decimal("55000"))
- ts = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ ts = datetime(2025, 1, 1, tzinfo=UTC)
closed = sim.check_stops(
candle_high=Decimal("52000"),
candle_low=Decimal("49000"),
@@ -212,7 +213,7 @@ def test_short_stop_loss():
signal = make_signal("AAPL", OrderSide.SELL, "50000", "0.1")
sim.execute(signal, stop_loss=Decimal("52000"))
- ts = datetime(2025, 1, 1, tzinfo=timezone.utc)
+ ts = datetime(2025, 1, 1, tzinfo=UTC)
closed = sim.check_stops(
candle_high=Decimal("53000"), # above stop_loss
candle_low=Decimal("49000"),
diff --git a/services/backtester/tests/test_walk_forward.py b/services/backtester/tests/test_walk_forward.py
index 5ab2e7b..b1aa12c 100644
--- a/services/backtester/tests/test_walk_forward.py
+++ b/services/backtester/tests/test_walk_forward.py
@@ -1,18 +1,18 @@
"""Tests for walk-forward analysis."""
import sys
-from pathlib import Path
+from datetime import UTC, datetime, timedelta
from decimal import Decimal
-from datetime import datetime, timedelta, timezone
-
+from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "strategy-engine"))
-from shared.models import Candle
from backtester.walk_forward import WalkForwardEngine, WalkForwardResult
from strategies.rsi_strategy import RsiStrategy
+from shared.models import Candle
+
def _generate_candles(n=100, base_price=100.0):
candles = []
@@ -21,9 +21,9 @@ def _generate_candles(n=100, base_price=100.0):
price = base_price + (i % 20) - 10
candles.append(
Candle(
- symbol="BTCUSDT",
+ symbol="AAPL",
timeframe="1h",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc) + timedelta(hours=i),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC) + timedelta(hours=i),
open=Decimal(str(price)),
high=Decimal(str(price + 5)),
low=Decimal(str(price - 5)),
diff --git a/services/data-collector/Dockerfile b/services/data-collector/Dockerfile
index 8cb8af4..4d154c5 100644
--- a/services/data-collector/Dockerfile
+++ b/services/data-collector/Dockerfile
@@ -1,8 +1,15 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/data-collector/ services/data-collector/
RUN pip install --no-cache-dir ./services/data-collector
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
ENV PYTHONPATH=/app
+USER appuser
CMD ["python", "-m", "data_collector.main"]
diff --git a/services/data-collector/src/data_collector/main.py b/services/data-collector/src/data_collector/main.py
index b42b34c..2d44848 100644
--- a/services/data-collector/src/data_collector/main.py
+++ b/services/data-collector/src/data_collector/main.py
@@ -2,6 +2,9 @@
import asyncio
+import aiohttp
+
+from data_collector.config import CollectorConfig
from shared.alpaca import AlpacaClient
from shared.broker import RedisBroker
from shared.db import Database
@@ -11,8 +14,7 @@ from shared.logging import setup_logging
from shared.metrics import ServiceMetrics
from shared.models import Candle
from shared.notifier import TelegramNotifier
-
-from data_collector.config import CollectorConfig
+from shared.shutdown import GracefulShutdown
# Health check port: base + 0
HEALTH_PORT_OFFSET = 0
@@ -45,8 +47,10 @@ async def fetch_latest_bars(
volume=Decimal(str(bar["v"])),
)
candles.append(candle)
- except Exception as exc:
- log.warning("fetch_bar_failed", symbol=symbol, error=str(exc))
+ except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
+ log.warning("fetch_bar_network_error", symbol=symbol, error=str(exc))
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("fetch_bar_parse_error", symbol=symbol, error=str(exc))
return candles
@@ -56,18 +60,18 @@ async def run() -> None:
metrics = ServiceMetrics("data_collector")
notifier = TelegramNotifier(
- bot_token=config.telegram_bot_token,
+ bot_token=config.telegram_bot_token.get_secret_value(),
chat_id=config.telegram_chat_id,
)
- db = Database(config.database_url)
+ db = Database(config.database_url.get_secret_value())
await db.connect()
- broker = RedisBroker(config.redis_url)
+ broker = RedisBroker(config.redis_url.get_secret_value())
alpaca = AlpacaClient(
- api_key=config.alpaca_api_key,
- api_secret=config.alpaca_api_secret,
+ api_key=config.alpaca_api_key.get_secret_value(),
+ api_secret=config.alpaca_api_secret.get_secret_value(),
paper=config.alpaca_paper,
)
@@ -83,14 +87,17 @@ async def run() -> None:
symbols = config.symbols
timeframe = config.timeframes[0] if config.timeframes else "1Day"
+ shutdown = GracefulShutdown()
+ shutdown.install_handlers()
+
log.info("starting", symbols=symbols, timeframe=timeframe, poll_interval=poll_interval)
try:
- while True:
+ while not shutdown.is_shutting_down:
# Check if market is open
try:
is_open = await alpaca.is_market_open()
- except Exception:
+ except (aiohttp.ClientError, ConnectionError, TimeoutError):
is_open = False
if is_open:
@@ -109,7 +116,7 @@ async def run() -> None:
await asyncio.sleep(poll_interval)
except Exception as exc:
- log.error("fatal_error", error=str(exc))
+ log.error("fatal_error", error=str(exc), exc_info=True)
await notifier.send_error(str(exc), "data-collector")
raise
finally:
diff --git a/services/data-collector/tests/test_storage.py b/services/data-collector/tests/test_storage.py
index be85578..51f3aee 100644
--- a/services/data-collector/tests/test_storage.py
+++ b/services/data-collector/tests/test_storage.py
@@ -1,19 +1,20 @@
"""Tests for storage module."""
-import pytest
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
-from shared.models import Candle
+import pytest
from data_collector.storage import CandleStorage
+from shared.models import Candle
+
-def _make_candle(symbol: str = "BTCUSDT") -> Candle:
+def _make_candle(symbol: str = "AAPL") -> Candle:
return Candle(
symbol=symbol,
timeframe="1m",
- open_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC),
open=Decimal("30000"),
high=Decimal("30100"),
low=Decimal("29900"),
@@ -39,11 +40,11 @@ async def test_storage_saves_to_db_and_publishes():
mock_broker.publish.assert_called_once()
stream_arg = mock_broker.publish.call_args[0][0]
- assert stream_arg == "candles.BTCUSDT"
+ assert stream_arg == "candles.AAPL"
data_arg = mock_broker.publish.call_args[0][1]
assert data_arg["type"] == "CANDLE"
- assert data_arg["data"]["symbol"] == "BTCUSDT"
+ assert data_arg["data"]["symbol"] == "AAPL"
@pytest.mark.asyncio
diff --git a/services/news-collector/Dockerfile b/services/news-collector/Dockerfile
new file mode 100644
index 0000000..7accee2
--- /dev/null
+++ b/services/news-collector/Dockerfile
@@ -0,0 +1,17 @@
+FROM python:3.12-slim AS builder
+WORKDIR /app
+COPY shared/ shared/
+RUN pip install --no-cache-dir ./shared
+COPY services/news-collector/ services/news-collector/
+RUN pip install --no-cache-dir ./services/news-collector
+RUN python -c "import nltk; nltk.download('vader_lexicon', download_dir='/usr/local/nltk_data')"
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
+COPY --from=builder /usr/local/nltk_data /usr/local/nltk_data
+ENV PYTHONPATH=/app
+USER appuser
+CMD ["python", "-m", "news_collector.main"]
diff --git a/services/news-collector/pyproject.toml b/services/news-collector/pyproject.toml
new file mode 100644
index 0000000..6e62b70
--- /dev/null
+++ b/services/news-collector/pyproject.toml
@@ -0,0 +1,20 @@
+[project]
+name = "news-collector"
+version = "0.1.0"
+description = "News and sentiment data collector service"
+requires-python = ">=3.12"
+dependencies = ["trading-shared", "feedparser>=6.0,<7", "nltk>=3.8,<4", "aiohttp>=3.9,<4"]
+
+[project.optional-dependencies]
+dev = [
+ "pytest>=8.0",
+ "pytest-asyncio>=0.23",
+ "aioresponses>=0.7",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/news_collector"]
diff --git a/services/news-collector/src/news_collector/__init__.py b/services/news-collector/src/news_collector/__init__.py
new file mode 100644
index 0000000..5547af2
--- /dev/null
+++ b/services/news-collector/src/news_collector/__init__.py
@@ -0,0 +1 @@
+"""News collector service."""
diff --git a/services/news-collector/src/news_collector/collectors/__init__.py b/services/news-collector/src/news_collector/collectors/__init__.py
new file mode 100644
index 0000000..5ef36a7
--- /dev/null
+++ b/services/news-collector/src/news_collector/collectors/__init__.py
@@ -0,0 +1 @@
+"""News collectors."""
diff --git a/services/news-collector/src/news_collector/collectors/base.py b/services/news-collector/src/news_collector/collectors/base.py
new file mode 100644
index 0000000..bb43fd6
--- /dev/null
+++ b/services/news-collector/src/news_collector/collectors/base.py
@@ -0,0 +1,18 @@
+"""Base class for all news collectors."""
+
+from abc import ABC, abstractmethod
+
+from shared.models import NewsItem
+
+
+class BaseCollector(ABC):
+ name: str = "base"
+ poll_interval: int = 300 # seconds
+
+ @abstractmethod
+ async def collect(self) -> list[NewsItem]:
+ """Collect news items from the source."""
+
+ @abstractmethod
+ async def is_available(self) -> bool:
+ """Check if this data source is accessible."""
diff --git a/services/news-collector/src/news_collector/collectors/fear_greed.py b/services/news-collector/src/news_collector/collectors/fear_greed.py
new file mode 100644
index 0000000..42e8f88
--- /dev/null
+++ b/services/news-collector/src/news_collector/collectors/fear_greed.py
@@ -0,0 +1,62 @@
+"""CNN Fear & Greed Index collector."""
+
+import logging
+from dataclasses import dataclass
+
+import aiohttp
+
+from news_collector.collectors.base import BaseCollector
+
+logger = logging.getLogger(__name__)
+
+FEAR_GREED_URL = "https://production.dataviz.cnn.io/index/fearandgreed/graphdata"
+
+
+@dataclass
+class FearGreedResult:
+ fear_greed: int
+ fear_greed_label: str
+
+
+class FearGreedCollector(BaseCollector):
+ name = "fear_greed"
+ poll_interval = 3600 # 1 hour
+
+ async def is_available(self) -> bool:
+ return True
+
+ async def _fetch_index(self) -> dict | None:
+ headers = {"User-Agent": "Mozilla/5.0"}
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(
+ FEAR_GREED_URL, headers=headers, timeout=aiohttp.ClientTimeout(total=10)
+ ) as resp:
+ if resp.status != 200:
+ return None
+ return await resp.json()
+ except Exception:
+ return None
+
+ def _classify(self, score: int) -> str:
+ if score <= 20:
+ return "Extreme Fear"
+ if score <= 40:
+ return "Fear"
+ if score <= 60:
+ return "Neutral"
+ if score <= 80:
+ return "Greed"
+ return "Extreme Greed"
+
+ async def collect(self) -> FearGreedResult | None:
+ data = await self._fetch_index()
+ if data is None:
+ return None
+ try:
+ fg = data["fear_and_greed"]
+ score = int(fg["score"])
+ label = fg.get("rating", self._classify(score))
+ return FearGreedResult(fear_greed=score, fear_greed_label=label)
+ except (KeyError, ValueError, TypeError):
+ return None
diff --git a/services/news-collector/src/news_collector/collectors/fed.py b/services/news-collector/src/news_collector/collectors/fed.py
new file mode 100644
index 0000000..52128e5
--- /dev/null
+++ b/services/news-collector/src/news_collector/collectors/fed.py
@@ -0,0 +1,119 @@
+"""Federal Reserve RSS collector with hawkish/dovish/neutral stance detection."""
+
+import asyncio
+import logging
+from calendar import timegm
+from datetime import UTC, datetime
+
+import feedparser
+from nltk.sentiment.vader import SentimentIntensityAnalyzer
+
+from shared.models import NewsCategory, NewsItem
+
+from .base import BaseCollector
+
+logger = logging.getLogger(__name__)
+
+_FED_RSS_URL = "https://www.federalreserve.gov/feeds/press_all.xml"
+
+_HAWKISH_KEYWORDS = [
+ "rate hike",
+ "interest rate increase",
+ "tighten",
+ "tightening",
+ "inflation",
+ "hawkish",
+ "restrictive",
+ "raise rates",
+ "hike rates",
+]
+_DOVISH_KEYWORDS = [
+ "rate cut",
+ "interest rate decrease",
+ "easing",
+ "ease",
+ "stimulus",
+ "dovish",
+ "accommodative",
+ "lower rates",
+ "cut rates",
+ "quantitative easing",
+]
+
+
+def _detect_stance(text: str) -> str:
+ lower = text.lower()
+ hawkish_hits = sum(1 for kw in _HAWKISH_KEYWORDS if kw in lower)
+ dovish_hits = sum(1 for kw in _DOVISH_KEYWORDS if kw in lower)
+ if hawkish_hits > dovish_hits:
+ return "hawkish"
+ if dovish_hits > hawkish_hits:
+ return "dovish"
+ return "neutral"
+
+
+class FedCollector(BaseCollector):
+ name: str = "fed"
+ poll_interval: int = 3600
+
+ def __init__(self) -> None:
+ self._vader = SentimentIntensityAnalyzer()
+
+ async def is_available(self) -> bool:
+ return True
+
+ async def _fetch_fed_rss(self) -> list[dict]:
+ loop = asyncio.get_event_loop()
+ try:
+ parsed = await loop.run_in_executor(None, feedparser.parse, _FED_RSS_URL)
+ return parsed.get("entries", [])
+ except Exception as exc:
+ logger.error("Fed RSS fetch failed: %s", exc)
+ return []
+
+ def _parse_published(self, entry: dict) -> datetime:
+ published_parsed = entry.get("published_parsed")
+ if published_parsed:
+ try:
+ ts = timegm(published_parsed)
+ return datetime.fromtimestamp(ts, tz=UTC)
+ except Exception:
+ pass
+ return datetime.now(UTC)
+
+ async def collect(self) -> list[NewsItem]:
+ try:
+ entries = await self._fetch_fed_rss()
+ except Exception as exc:
+ logger.error("Fed collector error: %s", exc)
+ return []
+
+ items: list[NewsItem] = []
+
+ for entry in entries:
+ title = entry.get("title", "").strip()
+ if not title:
+ continue
+
+ summary = entry.get("summary", "") or ""
+ combined = f"{title} {summary}"
+
+ sentiment = self._vader.polarity_scores(combined)["compound"]
+ stance = _detect_stance(combined)
+ published_at = self._parse_published(entry)
+
+ items.append(
+ NewsItem(
+ source=self.name,
+ headline=title,
+ summary=summary or None,
+ url=entry.get("link") or None,
+ published_at=published_at,
+ symbols=[],
+ sentiment=sentiment,
+ category=NewsCategory.FED,
+ raw_data={"stance": stance, **dict(entry)},
+ )
+ )
+
+ return items
diff --git a/services/news-collector/src/news_collector/collectors/finnhub.py b/services/news-collector/src/news_collector/collectors/finnhub.py
new file mode 100644
index 0000000..67cb455
--- /dev/null
+++ b/services/news-collector/src/news_collector/collectors/finnhub.py
@@ -0,0 +1,88 @@
+"""Finnhub news collector with VADER sentiment analysis."""
+
+import logging
+from datetime import UTC, datetime
+
+import aiohttp
+from nltk.sentiment.vader import SentimentIntensityAnalyzer
+
+from shared.models import NewsCategory, NewsItem
+
+from .base import BaseCollector
+
+logger = logging.getLogger(__name__)
+
+_CATEGORY_KEYWORDS: dict[NewsCategory, list[str]] = {
+ NewsCategory.FED: ["fed", "fomc", "rate", "federal reserve"],
+ NewsCategory.POLICY: ["tariff", "trump", "regulation", "policy", "trade war"],
+ NewsCategory.EARNINGS: ["earnings", "revenue", "profit", "eps", "guidance", "quarter"],
+}
+
+
+def _categorize(text: str) -> NewsCategory:
+ lower = text.lower()
+ for category, keywords in _CATEGORY_KEYWORDS.items():
+ if any(kw in lower for kw in keywords):
+ return category
+ return NewsCategory.MACRO
+
+
+class FinnhubCollector(BaseCollector):
+ name: str = "finnhub"
+ poll_interval: int = 300
+
+ _BASE_URL = "https://finnhub.io/api/v1/news"
+
+ def __init__(self, api_key: str) -> None:
+ self._api_key = api_key
+ self._vader = SentimentIntensityAnalyzer()
+
+ async def is_available(self) -> bool:
+ return bool(self._api_key)
+
+ async def _fetch_news(self) -> list[dict]:
+ url = f"{self._BASE_URL}?category=general&token={self._api_key}"
+ async with aiohttp.ClientSession() as session:
+ async with session.get(url) as resp:
+ resp.raise_for_status()
+ return await resp.json()
+
+ async def collect(self) -> list[NewsItem]:
+ try:
+ raw_items = await self._fetch_news()
+ except Exception as exc:
+ logger.error("Finnhub fetch failed: %s", exc)
+ return []
+
+ items: list[NewsItem] = []
+ for article in raw_items:
+ headline = article.get("headline", "")
+ summary = article.get("summary", "")
+ combined = f"{headline} {summary}"
+
+ sentiment_scores = self._vader.polarity_scores(combined)
+ sentiment = sentiment_scores["compound"]
+
+ ts = article.get("datetime", 0)
+ published_at = datetime.fromtimestamp(ts, tz=UTC)
+
+ related = article.get("related", "")
+ symbols = [t.strip() for t in related.split(",") if t.strip()] if related else []
+
+ category = _categorize(combined)
+
+ items.append(
+ NewsItem(
+ source=self.name,
+ headline=headline,
+ summary=summary or None,
+ url=article.get("url") or None,
+ published_at=published_at,
+ symbols=symbols,
+ sentiment=sentiment,
+ category=category,
+ raw_data=article,
+ )
+ )
+
+ return items
diff --git a/services/news-collector/src/news_collector/collectors/reddit.py b/services/news-collector/src/news_collector/collectors/reddit.py
new file mode 100644
index 0000000..4e9d6f5
--- /dev/null
+++ b/services/news-collector/src/news_collector/collectors/reddit.py
@@ -0,0 +1,97 @@
+"""Reddit social sentiment collector using JSON API with VADER sentiment analysis."""
+
+import logging
+import re
+from datetime import UTC, datetime
+
+import aiohttp
+from nltk.sentiment.vader import SentimentIntensityAnalyzer
+
+from shared.models import NewsCategory, NewsItem
+
+from .base import BaseCollector
+
+logger = logging.getLogger(__name__)
+
+_SUBREDDITS = ["wallstreetbets", "stocks", "investing"]
+_MIN_SCORE = 50
+
+_TICKER_PATTERN = re.compile(
+ r"\b(AAPL|MSFT|GOOGL|GOOG|AMZN|TSLA|NVDA|META|BRK\.?[AB]|JPM|V|UNH|XOM|"
+ r"JNJ|WMT|MA|PG|HD|CVX|MRK|LLY|ABBV|PFE|BAC|KO|AVGO|COST|MCD|TMO|"
+ r"CSCO|ACN|ABT|DHR|TXN|NEE|NFLX|PM|UPS|RTX|HON|QCOM|AMGN|LOW|IBM|"
+ r"INTC|AMD|PYPL|GS|MS|BLK|SPGI|CAT|DE|GE|MMM|BA|F|GM|DIS|CMCSA)\b"
+)
+
+
+class RedditCollector(BaseCollector):
+ name: str = "reddit"
+ poll_interval: int = 900
+
+ def __init__(self) -> None:
+ self._vader = SentimentIntensityAnalyzer()
+
+ async def is_available(self) -> bool:
+ return True
+
+ async def _fetch_subreddit(self, subreddit: str) -> list[dict]:
+ url = f"https://www.reddit.com/r/{subreddit}/hot.json?limit=25"
+ headers = {"User-Agent": "TradingPlatform/1.0 (research@example.com)"}
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(
+ url, headers=headers, timeout=aiohttp.ClientTimeout(total=10)
+ ) as resp:
+ if resp.status == 200:
+ data = await resp.json()
+ return data.get("data", {}).get("children", [])
+ except Exception as exc:
+ logger.error("Reddit fetch failed for r/%s: %s", subreddit, exc)
+ return []
+
+ async def collect(self) -> list[NewsItem]:
+ seen_titles: set[str] = set()
+ items: list[NewsItem] = []
+
+ for subreddit in _SUBREDDITS:
+ try:
+ posts = await self._fetch_subreddit(subreddit)
+ except Exception as exc:
+ logger.error("Reddit collector error for r/%s: %s", subreddit, exc)
+ continue
+
+ for post in posts:
+ post_data = post.get("data", {})
+ title = post_data.get("title", "").strip()
+ score = post_data.get("score", 0)
+
+ if not title or score < _MIN_SCORE:
+ continue
+ if title in seen_titles:
+ continue
+ seen_titles.add(title)
+
+ selftext = post_data.get("selftext", "") or ""
+ combined = f"{title} {selftext}"
+
+ sentiment = self._vader.polarity_scores(combined)["compound"]
+ symbols = list(dict.fromkeys(_TICKER_PATTERN.findall(combined)))
+
+ created_utc = post_data.get("created_utc", 0)
+ published_at = datetime.fromtimestamp(created_utc, tz=UTC)
+
+ items.append(
+ NewsItem(
+ source=self.name,
+ headline=title,
+ summary=selftext or None,
+ url=post_data.get("url") or None,
+ published_at=published_at,
+ symbols=symbols,
+ sentiment=sentiment,
+ category=NewsCategory.SOCIAL,
+ raw_data=post_data,
+ )
+ )
+
+ return items
diff --git a/services/news-collector/src/news_collector/collectors/rss.py b/services/news-collector/src/news_collector/collectors/rss.py
new file mode 100644
index 0000000..bca0e9f
--- /dev/null
+++ b/services/news-collector/src/news_collector/collectors/rss.py
@@ -0,0 +1,105 @@
+"""RSS news collector using feedparser with VADER sentiment analysis."""
+
+import asyncio
+import logging
+import re
+from datetime import UTC, datetime
+from time import mktime
+
+import feedparser
+from nltk.sentiment.vader import SentimentIntensityAnalyzer
+
+from shared.models import NewsCategory, NewsItem
+
+from .base import BaseCollector
+
+logger = logging.getLogger(__name__)
+
+_DEFAULT_FEEDS = [
+ "https://finance.yahoo.com/news/rssindex",
+ "https://news.google.com/rss/search?q=stock+market+finance&hl=en-US&gl=US&ceid=US:en",
+ "https://feeds.marketwatch.com/marketwatch/topstories/",
+]
+
+_TICKER_PATTERN = re.compile(
+ r"\b(AAPL|MSFT|GOOGL|GOOG|AMZN|TSLA|NVDA|META|BRK\.?[AB]|JPM|V|UNH|XOM|"
+ r"JNJ|WMT|MA|PG|HD|CVX|MRK|LLY|ABBV|PFE|BAC|KO|AVGO|COST|MCD|TMO|"
+ r"CSCO|ACN|ABT|DHR|TXN|NEE|NFLX|PM|UPS|RTX|HON|QCOM|AMGN|LOW|IBM|"
+ r"INTC|AMD|PYPL|GS|MS|BLK|SPGI|CAT|DE|GE|MMM|BA|F|GM|DIS|CMCSA)\b"
+)
+
+
+class RSSCollector(BaseCollector):
+ name: str = "rss"
+ poll_interval: int = 600
+
+ def __init__(self, feeds: list[str] | None = None) -> None:
+ self._feeds = feeds if feeds is not None else _DEFAULT_FEEDS
+ self._vader = SentimentIntensityAnalyzer()
+
+ async def is_available(self) -> bool:
+ return True
+
+ async def _fetch_feeds(self) -> list[dict]:
+ loop = asyncio.get_event_loop()
+ results = []
+ for url in self._feeds:
+ try:
+ parsed = await loop.run_in_executor(None, feedparser.parse, url)
+ results.append(parsed)
+ except Exception as exc:
+ logger.error("RSS fetch failed for %s: %s", url, exc)
+ return results
+
+ def _parse_published(self, entry: dict) -> datetime:
+ parsed_time = entry.get("published_parsed")
+ if parsed_time:
+ try:
+ ts = mktime(parsed_time)
+ return datetime.fromtimestamp(ts, tz=UTC)
+ except Exception:
+ pass
+ return datetime.now(UTC)
+
+ async def collect(self) -> list[NewsItem]:
+ try:
+ feeds = await self._fetch_feeds()
+ except Exception as exc:
+ logger.error("RSS collector error: %s", exc)
+ return []
+
+ seen_titles: set[str] = set()
+ items: list[NewsItem] = []
+
+ for feed in feeds:
+ for entry in feed.get("entries", []):
+ title = entry.get("title", "").strip()
+ if not title or title in seen_titles:
+ continue
+ seen_titles.add(title)
+
+ summary = entry.get("summary", "") or ""
+ combined = f"{title} {summary}"
+
+ sentiment_scores = self._vader.polarity_scores(combined)
+ sentiment = sentiment_scores["compound"]
+
+ symbols = list(dict.fromkeys(_TICKER_PATTERN.findall(combined)))
+
+ published_at = self._parse_published(entry)
+
+ items.append(
+ NewsItem(
+ source=self.name,
+ headline=title,
+ summary=summary or None,
+ url=entry.get("link") or None,
+ published_at=published_at,
+ symbols=symbols,
+ sentiment=sentiment,
+ category=NewsCategory.MACRO,
+ raw_data=dict(entry),
+ )
+ )
+
+ return items
diff --git a/services/news-collector/src/news_collector/collectors/sec_edgar.py b/services/news-collector/src/news_collector/collectors/sec_edgar.py
new file mode 100644
index 0000000..d88518f
--- /dev/null
+++ b/services/news-collector/src/news_collector/collectors/sec_edgar.py
@@ -0,0 +1,98 @@
+"""SEC EDGAR filing collector (free, no API key required)."""
+
+import logging
+from datetime import UTC, datetime
+
+import aiohttp
+from nltk.sentiment.vader import SentimentIntensityAnalyzer
+
+from news_collector.collectors.base import BaseCollector
+from shared.models import NewsCategory, NewsItem
+
+logger = logging.getLogger(__name__)
+
+TRACKED_CIKS = {
+ "0000320193": "AAPL",
+ "0000789019": "MSFT",
+ "0001652044": "GOOGL",
+ "0001018724": "AMZN",
+ "0001318605": "TSLA",
+ "0001045810": "NVDA",
+ "0001326801": "META",
+ "0000019617": "JPM",
+ "0000078003": "PFE",
+ "0000021344": "KO",
+}
+
+SEC_USER_AGENT = "TradingPlatform research@example.com"
+
+
+class SecEdgarCollector(BaseCollector):
+ name = "sec_edgar"
+ poll_interval = 1800 # 30 minutes
+
+ def __init__(self) -> None:
+ self._vader = SentimentIntensityAnalyzer()
+
+ async def is_available(self) -> bool:
+ return True
+
+ async def _fetch_recent_filings(self) -> list[dict]:
+ results = []
+ headers = {"User-Agent": SEC_USER_AGENT}
+ async with aiohttp.ClientSession() as session:
+ for cik, ticker in TRACKED_CIKS.items():
+ try:
+ url = f"https://data.sec.gov/submissions/CIK{cik}.json"
+ async with session.get(
+ url, headers=headers, timeout=aiohttp.ClientTimeout(total=10)
+ ) as resp:
+ if resp.status == 200:
+ data = await resp.json()
+ data["tickers"] = [{"ticker": ticker}]
+ results.append(data)
+ except Exception as exc:
+ logger.warning("sec_fetch_failed", cik=cik, error=str(exc))
+ return results
+
+ async def collect(self) -> list[NewsItem]:
+ filings_data = await self._fetch_recent_filings()
+ items = []
+ today = datetime.now(UTC).strftime("%Y-%m-%d")
+
+ for company_data in filings_data:
+ tickers = [t["ticker"] for t in company_data.get("tickers", [])]
+ company_name = company_data.get("name", "Unknown")
+ recent = company_data.get("filings", {}).get("recent", {})
+
+ forms = recent.get("form", [])
+ dates = recent.get("filingDate", [])
+ descriptions = recent.get("primaryDocDescription", [])
+ accessions = recent.get("accessionNumber", [])
+
+ for i, form in enumerate(forms):
+ if form != "8-K":
+ continue
+ filing_date = dates[i] if i < len(dates) else ""
+ if filing_date != today:
+ continue
+
+ desc = descriptions[i] if i < len(descriptions) else "8-K Filing"
+ accession = accessions[i] if i < len(accessions) else ""
+ headline = f"{company_name} ({', '.join(tickers)}): {form} - {desc}"
+
+ items.append(
+ NewsItem(
+ source=self.name,
+ headline=headline,
+ summary=desc,
+ url=f"https://www.sec.gov/cgi-bin/browse-edgar?action=getcompany&accession={accession}",
+ published_at=datetime.strptime(filing_date, "%Y-%m-%d").replace(tzinfo=UTC),
+ symbols=tickers,
+ sentiment=self._vader.polarity_scores(headline)["compound"],
+ category=NewsCategory.FILING,
+ raw_data={"form": form, "accession": accession},
+ )
+ )
+
+ return items
diff --git a/services/news-collector/src/news_collector/collectors/truth_social.py b/services/news-collector/src/news_collector/collectors/truth_social.py
new file mode 100644
index 0000000..e2acd88
--- /dev/null
+++ b/services/news-collector/src/news_collector/collectors/truth_social.py
@@ -0,0 +1,86 @@
+"""Truth Social collector using Mastodon-compatible API with VADER sentiment analysis."""
+
+import logging
+import re
+from datetime import UTC, datetime
+
+import aiohttp
+from nltk.sentiment.vader import SentimentIntensityAnalyzer
+
+from shared.models import NewsCategory, NewsItem
+
+from .base import BaseCollector
+
+logger = logging.getLogger(__name__)
+
+_TRUMP_ACCOUNT_ID = "107780257626128497"
+_API_URL = f"https://truthsocial.com/api/v1/accounts/{_TRUMP_ACCOUNT_ID}/statuses"
+
+_HTML_TAG_PATTERN = re.compile(r"<[^>]+>")
+
+
+def _strip_html(text: str) -> str:
+ return _HTML_TAG_PATTERN.sub("", text).strip()
+
+
+class TruthSocialCollector(BaseCollector):
+ name: str = "truth_social"
+ poll_interval: int = 900
+
+ def __init__(self) -> None:
+ self._vader = SentimentIntensityAnalyzer()
+
+ async def is_available(self) -> bool:
+ return True
+
+ async def _fetch_posts(self) -> list[dict]:
+ headers = {"User-Agent": "TradingPlatform/1.0 (research@example.com)"}
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(
+ _API_URL, headers=headers, timeout=aiohttp.ClientTimeout(total=10)
+ ) as resp:
+ if resp.status == 200:
+ return await resp.json()
+ except Exception as exc:
+ logger.error("Truth Social fetch failed: %s", exc)
+ return []
+
+ async def collect(self) -> list[NewsItem]:
+ try:
+ posts = await self._fetch_posts()
+ except Exception as exc:
+ logger.error("Truth Social collector error: %s", exc)
+ return []
+
+ items: list[NewsItem] = []
+
+ for post in posts:
+ raw_content = post.get("content", "") or ""
+ content = _strip_html(raw_content)
+ if not content:
+ continue
+
+ sentiment = self._vader.polarity_scores(content)["compound"]
+
+ created_at_str = post.get("created_at", "")
+ try:
+ published_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
+ except Exception:
+ published_at = datetime.now(UTC)
+
+ items.append(
+ NewsItem(
+ source=self.name,
+ headline=content[:200],
+ summary=content if len(content) > 200 else None,
+ url=post.get("url") or None,
+ published_at=published_at,
+ symbols=[],
+ sentiment=sentiment,
+ category=NewsCategory.POLICY,
+ raw_data=post,
+ )
+ )
+
+ return items
diff --git a/services/news-collector/src/news_collector/config.py b/services/news-collector/src/news_collector/config.py
new file mode 100644
index 0000000..6e78eba
--- /dev/null
+++ b/services/news-collector/src/news_collector/config.py
@@ -0,0 +1,7 @@
+"""News Collector configuration."""
+
+from shared.config import Settings
+
+
+class NewsCollectorConfig(Settings):
+ health_port: int = 8084
diff --git a/services/news-collector/src/news_collector/main.py b/services/news-collector/src/news_collector/main.py
new file mode 100644
index 0000000..c39fa67
--- /dev/null
+++ b/services/news-collector/src/news_collector/main.py
@@ -0,0 +1,204 @@
+"""News Collector Service — fetches news from multiple sources and aggregates sentiment."""
+
+import asyncio
+from datetime import UTC, datetime
+
+import aiohttp
+
+from news_collector.collectors.fear_greed import FearGreedCollector
+from news_collector.collectors.fed import FedCollector
+from news_collector.collectors.finnhub import FinnhubCollector
+from news_collector.collectors.reddit import RedditCollector
+from news_collector.collectors.rss import RSSCollector
+from news_collector.collectors.sec_edgar import SecEdgarCollector
+from news_collector.collectors.truth_social import TruthSocialCollector
+from news_collector.config import NewsCollectorConfig
+from shared.broker import RedisBroker
+from shared.db import Database
+from shared.events import NewsEvent
+from shared.healthcheck import HealthCheckServer
+from shared.logging import setup_logging
+from shared.metrics import ServiceMetrics
+from shared.models import NewsItem
+from shared.notifier import TelegramNotifier
+from shared.sentiment import SentimentAggregator
+from shared.sentiment_models import MarketSentiment
+from shared.shutdown import GracefulShutdown
+
+
+async def run_collector_once(collector, db: Database, broker: RedisBroker) -> int:
+ """Run a single collector, store results in DB, publish to Redis.
+
+ Returns the number of items collected.
+ """
+ items: list[NewsItem] = await collector.collect()
+ count = 0
+ for item in items:
+ await db.insert_news_item(item)
+ event = NewsEvent(data=item)
+ stream = f"news.{item.category.value}"
+ await broker.publish(stream, event.to_dict())
+ count += 1
+ return count
+
+
+async def run_collector_loop(collector, db: Database, broker: RedisBroker, log) -> None:
+ """Run a collector repeatedly on its configured poll_interval."""
+ while True:
+ try:
+ count = await run_collector_once(collector, db, broker)
+ log.info(
+ "collector_ran",
+ collector=collector.name,
+ count=count,
+ )
+ except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
+ log.warning(
+ "collector_network_error",
+ collector=collector.name,
+ error=str(exc),
+ )
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning(
+ "collector_parse_error",
+ collector=collector.name,
+ error=str(exc),
+ )
+ await asyncio.sleep(collector.poll_interval)
+
+
+async def run_fear_greed_loop(collector: FearGreedCollector, db: Database, log) -> None:
+ """Fetch Fear & Greed index on its interval and update MarketSentiment in DB."""
+ while True:
+ try:
+ result = await collector.collect()
+ if result is not None:
+ ms = MarketSentiment(
+ fear_greed=result.fear_greed,
+ fear_greed_label=result.fear_greed_label,
+ vix=None,
+ fed_stance="neutral",
+ market_regime=_determine_regime(result.fear_greed, None),
+ updated_at=datetime.now(UTC),
+ )
+ await db.upsert_market_sentiment(ms)
+ log.info(
+ "fear_greed_updated",
+ value=result.fear_greed,
+ label=result.fear_greed_label,
+ )
+ except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
+ log.warning("fear_greed_network_error", error=str(exc))
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("fear_greed_parse_error", error=str(exc))
+ await asyncio.sleep(collector.poll_interval)
+
+
+async def run_aggregator_loop(db: Database, interval: int, log) -> None:
+ """Run SentimentAggregator every interval seconds and persist scores."""
+ aggregator = SentimentAggregator()
+ while True:
+ await asyncio.sleep(interval)
+ try:
+ now = datetime.now(UTC)
+ news_items = await db.get_recent_news(hours=24)
+ scores = aggregator.aggregate(news_items, now)
+ for score in scores.values():
+ await db.upsert_symbol_score(score)
+ log.info("aggregation_complete", symbols=len(scores))
+ except (ConnectionError, TimeoutError) as exc:
+ log.warning("aggregator_network_error", error=str(exc))
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("aggregator_parse_error", error=str(exc))
+
+
+def _determine_regime(fear_greed: int, vix: float | None) -> str:
+ """Classify market regime from fear/greed index and optional VIX."""
+ aggregator = SentimentAggregator()
+ return aggregator.determine_regime(fear_greed, vix)
+
+
+async def run() -> None:
+ config = NewsCollectorConfig()
+ log = setup_logging("news-collector", config.log_level, config.log_format)
+ metrics = ServiceMetrics("news_collector")
+
+ notifier = TelegramNotifier(
+ bot_token=config.telegram_bot_token.get_secret_value(),
+ chat_id=config.telegram_chat_id,
+ )
+
+ db = Database(config.database_url.get_secret_value())
+ await db.connect()
+
+ broker = RedisBroker(config.redis_url.get_secret_value())
+
+ health = HealthCheckServer(
+ "news-collector",
+ port=config.health_port,
+ auth_token=config.metrics_auth_token,
+ )
+ await health.start()
+ metrics.service_up.labels(service="news-collector").set(1)
+
+ # Build collectors
+ finnhub = FinnhubCollector(api_key=config.finnhub_api_key.get_secret_value())
+ rss = RSSCollector()
+ sec = SecEdgarCollector()
+ truth = TruthSocialCollector()
+ reddit = RedditCollector()
+ fear_greed = FearGreedCollector()
+ fed = FedCollector()
+
+ news_collectors = [finnhub, rss, sec, truth, reddit, fed]
+
+ shutdown = GracefulShutdown()
+ shutdown.install_handlers()
+
+ log.info(
+ "starting",
+ collectors=[c.name for c in news_collectors],
+ poll_interval=config.news_poll_interval,
+ aggregate_interval=config.sentiment_aggregate_interval,
+ )
+
+ try:
+ tasks = [
+ asyncio.create_task(
+ run_collector_loop(collector, db, broker, log),
+ name=f"collector-{collector.name}",
+ )
+ for collector in news_collectors
+ ]
+ tasks.append(
+ asyncio.create_task(
+ run_fear_greed_loop(fear_greed, db, log),
+ name="fear-greed-loop",
+ )
+ )
+ tasks.append(
+ asyncio.create_task(
+ run_aggregator_loop(db, config.sentiment_aggregate_interval, log),
+ name="aggregator-loop",
+ )
+ )
+ await shutdown.wait()
+ except Exception as exc:
+ log.error("fatal_error", error=str(exc), exc_info=True)
+ await notifier.send_error(str(exc), "news-collector")
+ raise
+ finally:
+ metrics.service_up.labels(service="news-collector").set(0)
+ for task in tasks:
+ task.cancel()
+ await notifier.close()
+ await broker.close()
+ await db.close()
+
+
+def main() -> None:
+ asyncio.run(run())
+
+
+if __name__ == "__main__":
+ main()
diff --git a/services/news-collector/tests/__init__.py b/services/news-collector/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/services/news-collector/tests/__init__.py
diff --git a/services/news-collector/tests/test_fear_greed.py b/services/news-collector/tests/test_fear_greed.py
new file mode 100644
index 0000000..e8bd8f0
--- /dev/null
+++ b/services/news-collector/tests/test_fear_greed.py
@@ -0,0 +1,49 @@
+"""Tests for CNN Fear & Greed Index collector."""
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from news_collector.collectors.fear_greed import FearGreedCollector
+
+
+@pytest.fixture
+def collector():
+ return FearGreedCollector()
+
+
+def test_collector_name(collector):
+ assert collector.name == "fear_greed"
+ assert collector.poll_interval == 3600
+
+
+async def test_is_available(collector):
+ assert await collector.is_available() is True
+
+
+async def test_collect_parses_api_response(collector):
+ mock_data = {
+ "fear_and_greed": {
+ "score": 45.0,
+ "rating": "Fear",
+ "timestamp": "2026-04-02T12:00:00+00:00",
+ }
+ }
+ with patch.object(collector, "_fetch_index", new_callable=AsyncMock, return_value=mock_data):
+ result = await collector.collect()
+ assert result.fear_greed == 45
+ assert result.fear_greed_label == "Fear"
+
+
+async def test_collect_returns_none_on_failure(collector):
+ with patch.object(collector, "_fetch_index", new_callable=AsyncMock, return_value=None):
+ result = await collector.collect()
+ assert result is None
+
+
+def test_classify_label():
+ c = FearGreedCollector()
+ assert c._classify(10) == "Extreme Fear"
+ assert c._classify(30) == "Fear"
+ assert c._classify(50) == "Neutral"
+ assert c._classify(70) == "Greed"
+ assert c._classify(85) == "Extreme Greed"
diff --git a/services/news-collector/tests/test_fed.py b/services/news-collector/tests/test_fed.py
new file mode 100644
index 0000000..7f1c46c
--- /dev/null
+++ b/services/news-collector/tests/test_fed.py
@@ -0,0 +1,38 @@
+"""Tests for Federal Reserve collector."""
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from news_collector.collectors.fed import FedCollector
+
+
+@pytest.fixture
+def collector():
+ return FedCollector()
+
+
+def test_collector_name(collector):
+ assert collector.name == "fed"
+ assert collector.poll_interval == 3600
+
+
+async def test_is_available(collector):
+ assert await collector.is_available() is True
+
+
+async def test_collect_parses_rss(collector):
+ mock_entries = [
+ {
+ "title": "Federal Reserve issues FOMC statement",
+ "link": "https://www.federalreserve.gov/newsevents/pressreleases/monetary20260402a.htm",
+ "published_parsed": (2026, 4, 2, 14, 0, 0, 0, 0, 0),
+ "summary": "The Federal Open Market Committee decided to maintain the target range...",
+ },
+ ]
+ with patch.object(
+ collector, "_fetch_fed_rss", new_callable=AsyncMock, return_value=mock_entries
+ ):
+ items = await collector.collect()
+ assert len(items) == 1
+ assert items[0].source == "fed"
+ assert items[0].category.value == "fed"
diff --git a/services/news-collector/tests/test_finnhub.py b/services/news-collector/tests/test_finnhub.py
new file mode 100644
index 0000000..3af65b8
--- /dev/null
+++ b/services/news-collector/tests/test_finnhub.py
@@ -0,0 +1,67 @@
+"""Tests for Finnhub news collector."""
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from news_collector.collectors.finnhub import FinnhubCollector
+
+
+@pytest.fixture
+def collector():
+ return FinnhubCollector(api_key="test_key")
+
+
+def test_collector_name(collector):
+ assert collector.name == "finnhub"
+ assert collector.poll_interval == 300
+
+
+async def test_is_available_with_key(collector):
+ assert await collector.is_available() is True
+
+
+async def test_is_available_without_key():
+ c = FinnhubCollector(api_key="")
+ assert await c.is_available() is False
+
+
+async def test_collect_parses_response(collector):
+ mock_response = [
+ {
+ "category": "top news",
+ "datetime": 1711929600,
+ "headline": "AAPL beats earnings",
+ "id": 12345,
+ "related": "AAPL",
+ "source": "MarketWatch",
+ "summary": "Apple reported better than expected...",
+ "url": "https://example.com/article",
+ },
+ {
+ "category": "top news",
+ "datetime": 1711929000,
+ "headline": "Fed holds rates steady",
+ "id": 12346,
+ "related": "",
+ "source": "Reuters",
+ "summary": "The Federal Reserve...",
+ "url": "https://example.com/fed",
+ },
+ ]
+
+ with patch.object(collector, "_fetch_news", new_callable=AsyncMock, return_value=mock_response):
+ items = await collector.collect()
+
+ assert len(items) == 2
+ assert items[0].source == "finnhub"
+ assert items[0].headline == "AAPL beats earnings"
+ assert items[0].symbols == ["AAPL"]
+ assert items[0].url == "https://example.com/article"
+ assert isinstance(items[0].sentiment, float)
+ assert items[1].symbols == []
+
+
+async def test_collect_handles_empty_response(collector):
+ with patch.object(collector, "_fetch_news", new_callable=AsyncMock, return_value=[]):
+ items = await collector.collect()
+ assert items == []
diff --git a/services/news-collector/tests/test_main.py b/services/news-collector/tests/test_main.py
new file mode 100644
index 0000000..f85569a
--- /dev/null
+++ b/services/news-collector/tests/test_main.py
@@ -0,0 +1,41 @@
+"""Tests for news collector scheduler."""
+
+from datetime import UTC, datetime
+from unittest.mock import AsyncMock, MagicMock
+
+from news_collector.main import run_collector_once
+
+from shared.models import NewsCategory, NewsItem
+
+
+async def test_run_collector_once_stores_and_publishes():
+ mock_item = NewsItem(
+ source="test",
+ headline="Test news",
+ published_at=datetime(2026, 4, 2, tzinfo=UTC),
+ sentiment=0.5,
+ category=NewsCategory.MACRO,
+ )
+ mock_collector = MagicMock()
+ mock_collector.name = "test"
+ mock_collector.collect = AsyncMock(return_value=[mock_item])
+ mock_db = MagicMock()
+ mock_db.insert_news_item = AsyncMock()
+ mock_broker = MagicMock()
+ mock_broker.publish = AsyncMock()
+
+ count = await run_collector_once(mock_collector, mock_db, mock_broker)
+ assert count == 1
+ mock_db.insert_news_item.assert_called_once_with(mock_item)
+ mock_broker.publish.assert_called_once()
+
+
+async def test_run_collector_once_handles_empty():
+ mock_collector = MagicMock()
+ mock_collector.name = "test"
+ mock_collector.collect = AsyncMock(return_value=[])
+ mock_db = MagicMock()
+ mock_broker = MagicMock()
+
+ count = await run_collector_once(mock_collector, mock_db, mock_broker)
+ assert count == 0
diff --git a/services/news-collector/tests/test_reddit.py b/services/news-collector/tests/test_reddit.py
new file mode 100644
index 0000000..31b1dc1
--- /dev/null
+++ b/services/news-collector/tests/test_reddit.py
@@ -0,0 +1,64 @@
+"""Tests for Reddit collector."""
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from news_collector.collectors.reddit import RedditCollector
+
+
+@pytest.fixture
+def collector():
+ return RedditCollector()
+
+
+def test_collector_name(collector):
+ assert collector.name == "reddit"
+ assert collector.poll_interval == 900
+
+
+async def test_is_available(collector):
+ assert await collector.is_available() is True
+
+
+async def test_collect_parses_posts(collector):
+ mock_posts = [
+ {
+ "data": {
+ "title": "NVDA to the moon! AI demand is insane",
+ "selftext": "Just loaded up on NVDA calls",
+ "url": "https://reddit.com/r/wallstreetbets/123",
+ "created_utc": 1711929600,
+ "score": 500,
+ "num_comments": 200,
+ "subreddit": "wallstreetbets",
+ }
+ },
+ ]
+ with patch.object(
+ collector, "_fetch_subreddit", new_callable=AsyncMock, return_value=mock_posts
+ ):
+ items = await collector.collect()
+ assert len(items) >= 1
+ assert items[0].source == "reddit"
+ assert items[0].category.value == "social"
+
+
+async def test_collect_filters_low_score(collector):
+ mock_posts = [
+ {
+ "data": {
+ "title": "Random question",
+ "selftext": "",
+ "url": "https://reddit.com/456",
+ "created_utc": 1711929600,
+ "score": 3,
+ "num_comments": 1,
+ "subreddit": "stocks",
+ }
+ },
+ ]
+ with patch.object(
+ collector, "_fetch_subreddit", new_callable=AsyncMock, return_value=mock_posts
+ ):
+ items = await collector.collect()
+ assert items == []
diff --git a/services/news-collector/tests/test_rss.py b/services/news-collector/tests/test_rss.py
new file mode 100644
index 0000000..7242c75
--- /dev/null
+++ b/services/news-collector/tests/test_rss.py
@@ -0,0 +1,47 @@
+"""Tests for RSS news collector."""
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from news_collector.collectors.rss import RSSCollector
+
+
+@pytest.fixture
+def collector():
+ return RSSCollector()
+
+
+def test_collector_name(collector):
+ assert collector.name == "rss"
+ assert collector.poll_interval == 600
+
+
+async def test_is_available(collector):
+ assert await collector.is_available() is True
+
+
+async def test_collect_parses_feed(collector):
+ mock_feed = {
+ "entries": [
+ {
+ "title": "NVDA surges on AI demand",
+ "link": "https://example.com/nvda",
+ "published_parsed": (2026, 4, 2, 12, 0, 0, 0, 0, 0),
+ "summary": "Nvidia stock jumped 5%...",
+ },
+ {
+ "title": "Markets rally on jobs data",
+ "link": "https://example.com/market",
+ "published_parsed": (2026, 4, 2, 11, 0, 0, 0, 0, 0),
+ "summary": "The S&P 500 rose...",
+ },
+ ],
+ }
+
+ with patch.object(collector, "_fetch_feeds", new_callable=AsyncMock, return_value=[mock_feed]):
+ items = await collector.collect()
+
+ assert len(items) == 2
+ assert items[0].source == "rss"
+ assert items[0].headline == "NVDA surges on AI demand"
+ assert isinstance(items[0].sentiment, float)
diff --git a/services/news-collector/tests/test_sec_edgar.py b/services/news-collector/tests/test_sec_edgar.py
new file mode 100644
index 0000000..b0faf18
--- /dev/null
+++ b/services/news-collector/tests/test_sec_edgar.py
@@ -0,0 +1,58 @@
+"""Tests for SEC EDGAR filing collector."""
+
+from datetime import UTC, datetime
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from news_collector.collectors.sec_edgar import SecEdgarCollector
+
+
+@pytest.fixture
+def collector():
+ return SecEdgarCollector()
+
+
+def test_collector_name(collector):
+ assert collector.name == "sec_edgar"
+ assert collector.poll_interval == 1800
+
+
+async def test_is_available(collector):
+ assert await collector.is_available() is True
+
+
+async def test_collect_parses_filings(collector):
+ mock_response = {
+ "filings": {
+ "recent": {
+ "accessionNumber": ["0001234-26-000001"],
+ "filingDate": ["2026-04-02"],
+ "primaryDocument": ["filing.htm"],
+ "form": ["8-K"],
+ "primaryDocDescription": ["Current Report"],
+ }
+ },
+ "tickers": [{"ticker": "AAPL"}],
+ "name": "Apple Inc",
+ }
+
+ mock_datetime = MagicMock(spec=datetime)
+ mock_datetime.now.return_value = datetime(2026, 4, 2, tzinfo=UTC)
+ mock_datetime.strptime = datetime.strptime
+
+ with patch.object(
+ collector, "_fetch_recent_filings", new_callable=AsyncMock, return_value=[mock_response]
+ ):
+ with patch("news_collector.collectors.sec_edgar.datetime", mock_datetime):
+ items = await collector.collect()
+
+ assert len(items) == 1
+ assert items[0].source == "sec_edgar"
+ assert items[0].category.value == "filing"
+ assert "AAPL" in items[0].symbols
+
+
+async def test_collect_handles_empty(collector):
+ with patch.object(collector, "_fetch_recent_filings", new_callable=AsyncMock, return_value=[]):
+ items = await collector.collect()
+ assert items == []
diff --git a/services/news-collector/tests/test_truth_social.py b/services/news-collector/tests/test_truth_social.py
new file mode 100644
index 0000000..52f1e46
--- /dev/null
+++ b/services/news-collector/tests/test_truth_social.py
@@ -0,0 +1,42 @@
+"""Tests for Truth Social collector."""
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from news_collector.collectors.truth_social import TruthSocialCollector
+
+
+@pytest.fixture
+def collector():
+ return TruthSocialCollector()
+
+
+def test_collector_name(collector):
+ assert collector.name == "truth_social"
+ assert collector.poll_interval == 900
+
+
+async def test_is_available(collector):
+ assert await collector.is_available() is True
+
+
+async def test_collect_parses_posts(collector):
+ mock_posts = [
+ {
+ "content": "<p>We are imposing 25% tariffs on all steel imports!</p>",
+ "created_at": "2026-04-02T12:00:00.000Z",
+ "url": "https://truthsocial.com/@realDonaldTrump/12345",
+ "id": "12345",
+ },
+ ]
+ with patch.object(collector, "_fetch_posts", new_callable=AsyncMock, return_value=mock_posts):
+ items = await collector.collect()
+ assert len(items) == 1
+ assert items[0].source == "truth_social"
+ assert items[0].category.value == "policy"
+
+
+async def test_collect_handles_empty(collector):
+ with patch.object(collector, "_fetch_posts", new_callable=AsyncMock, return_value=[]):
+ items = await collector.collect()
+ assert items == []
diff --git a/services/order-executor/Dockerfile b/services/order-executor/Dockerfile
index bc8b21c..376afec 100644
--- a/services/order-executor/Dockerfile
+++ b/services/order-executor/Dockerfile
@@ -1,8 +1,15 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/order-executor/ services/order-executor/
RUN pip install --no-cache-dir ./services/order-executor
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
ENV PYTHONPATH=/app
+USER appuser
CMD ["python", "-m", "order_executor.main"]
diff --git a/services/order-executor/src/order_executor/executor.py b/services/order-executor/src/order_executor/executor.py
index a71e762..fd502cd 100644
--- a/services/order-executor/src/order_executor/executor.py
+++ b/services/order-executor/src/order_executor/executor.py
@@ -1,18 +1,18 @@
"""Order execution logic."""
-import structlog
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
-from typing import Any, Optional
+from typing import Any
+
+import structlog
+from order_executor.risk_manager import RiskManager
from shared.broker import RedisBroker
from shared.db import Database
from shared.events import OrderEvent
from shared.models import Order, OrderStatus, OrderType, Signal
from shared.notifier import TelegramNotifier
-from order_executor.risk_manager import RiskManager
-
logger = structlog.get_logger()
@@ -35,7 +35,7 @@ class OrderExecutor:
self.notifier = notifier
self.dry_run = dry_run
- async def execute(self, signal: Signal) -> Optional[Order]:
+ async def execute(self, signal: Signal) -> Order | None:
"""Run risk checks and place an order for the given signal."""
# Fetch buying power from Alpaca
balance = await self.exchange.get_buying_power()
@@ -71,7 +71,7 @@ class OrderExecutor:
if self.dry_run:
order.status = OrderStatus.FILLED
- order.filled_at = datetime.now(timezone.utc)
+ order.filled_at = datetime.now(UTC)
logger.info(
"order_filled_dry_run",
side=str(order.side),
@@ -87,7 +87,7 @@ class OrderExecutor:
type="market",
)
order.status = OrderStatus.FILLED
- order.filled_at = datetime.now(timezone.utc)
+ order.filled_at = datetime.now(UTC)
logger.info(
"order_filled",
side=str(order.side),
diff --git a/services/order-executor/src/order_executor/main.py b/services/order-executor/src/order_executor/main.py
index 51ab286..99f88e1 100644
--- a/services/order-executor/src/order_executor/main.py
+++ b/services/order-executor/src/order_executor/main.py
@@ -3,6 +3,11 @@
import asyncio
from decimal import Decimal
+import aiohttp
+
+from order_executor.config import ExecutorConfig
+from order_executor.executor import OrderExecutor
+from order_executor.risk_manager import RiskManager
from shared.alpaca import AlpacaClient
from shared.broker import RedisBroker
from shared.db import Database
@@ -11,10 +16,7 @@ from shared.healthcheck import HealthCheckServer
from shared.logging import setup_logging
from shared.metrics import ServiceMetrics
from shared.notifier import TelegramNotifier
-
-from order_executor.config import ExecutorConfig
-from order_executor.executor import OrderExecutor
-from order_executor.risk_manager import RiskManager
+from shared.shutdown import GracefulShutdown
# Health check port: base + 2
HEALTH_PORT_OFFSET = 2
@@ -26,18 +28,18 @@ async def run() -> None:
metrics = ServiceMetrics("order_executor")
notifier = TelegramNotifier(
- bot_token=config.telegram_bot_token,
+ bot_token=config.telegram_bot_token.get_secret_value(),
chat_id=config.telegram_chat_id,
)
- db = Database(config.database_url)
+ db = Database(config.database_url.get_secret_value())
await db.connect()
- broker = RedisBroker(config.redis_url)
+ broker = RedisBroker(config.redis_url.get_secret_value())
alpaca = AlpacaClient(
- api_key=config.alpaca_api_key,
- api_secret=config.alpaca_api_secret,
+ api_key=config.alpaca_api_key.get_secret_value(),
+ api_secret=config.alpaca_api_secret.get_secret_value(),
paper=config.alpaca_paper,
)
@@ -83,6 +85,9 @@ async def run() -> None:
await broker.ensure_group(stream, GROUP)
+ shutdown = GracefulShutdown()
+ shutdown.install_handlers()
+
log.info("started", stream=stream, dry_run=config.dry_run)
try:
@@ -94,10 +99,15 @@ async def run() -> None:
if event.type == EventType.SIGNAL:
await executor.execute(event.data)
await broker.ack(stream, GROUP, msg_id)
+ except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
+ log.warning("pending_network_error", error=str(exc), msg_id=msg_id)
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("pending_parse_error", error=str(exc), msg_id=msg_id)
+ await broker.ack(stream, GROUP, msg_id)
except Exception as exc:
- log.error("pending_failed", error=str(exc), msg_id=msg_id)
+ log.error("pending_failed", error=str(exc), msg_id=msg_id, exc_info=True)
- while True:
+ while not shutdown.is_shutting_down:
messages = await broker.read_group(stream, GROUP, CONSUMER, count=10, block=5000)
for msg_id, msg in messages:
try:
@@ -110,8 +120,19 @@ async def run() -> None:
service="order-executor", event_type="signal"
).inc()
await broker.ack(stream, GROUP, msg_id)
+ except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
+ log.warning("process_network_error", error=str(exc))
+ metrics.errors_total.labels(
+ service="order-executor", error_type="network"
+ ).inc()
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("process_parse_error", error=str(exc))
+ await broker.ack(stream, GROUP, msg_id)
+ metrics.errors_total.labels(
+ service="order-executor", error_type="validation"
+ ).inc()
except Exception as exc:
- log.error("process_failed", error=str(exc))
+ log.error("process_failed", error=str(exc), exc_info=True)
metrics.errors_total.labels(
service="order-executor", error_type="processing"
).inc()
diff --git a/services/order-executor/src/order_executor/risk_manager.py b/services/order-executor/src/order_executor/risk_manager.py
index 5a05746..811a862 100644
--- a/services/order-executor/src/order_executor/risk_manager.py
+++ b/services/order-executor/src/order_executor/risk_manager.py
@@ -1,12 +1,12 @@
"""Risk management for order execution."""
+import math
+from collections import deque
from dataclasses import dataclass
-from datetime import datetime, timezone, timedelta
+from datetime import UTC, datetime, timedelta
from decimal import Decimal
-from collections import deque
-import math
-from shared.models import Signal, OrderSide, Position
+from shared.models import OrderSide, Position, Signal
@dataclass
@@ -123,15 +123,13 @@ class RiskManager:
else:
self._consecutive_losses += 1
if self._consecutive_losses >= self._max_consecutive_losses:
- self._paused_until = datetime.now(timezone.utc) + timedelta(
- minutes=self._loss_pause_minutes
- )
+ self._paused_until = datetime.now(UTC) + timedelta(minutes=self._loss_pause_minutes)
def is_paused(self) -> bool:
"""Check if trading is paused due to consecutive losses."""
if self._paused_until is None:
return False
- if datetime.now(timezone.utc) >= self._paused_until:
+ if datetime.now(UTC) >= self._paused_until:
self._paused_until = None
self._consecutive_losses = 0
return False
@@ -233,9 +231,9 @@ class RiskManager:
mean_a = sum(returns_a) / len(returns_a)
mean_b = sum(returns_b) / len(returns_b)
- cov = sum((a - mean_a) * (b - mean_b) for a, b in zip(returns_a, returns_b)) / len(
- returns_a
- )
+ cov = sum(
+ (a - mean_a) * (b - mean_b) for a, b in zip(returns_a, returns_b, strict=True)
+ ) / len(returns_a)
std_a = math.sqrt(sum((a - mean_a) ** 2 for a in returns_a) / len(returns_a))
std_b = math.sqrt(sum((b - mean_b) ** 2 for b in returns_b) / len(returns_b))
@@ -280,7 +278,11 @@ class RiskManager:
min_len = min(len(r) for r in all_returns)
portfolio_returns = []
for i in range(min_len):
- pr = sum(w * r[-(min_len - i)] for w, r in zip(weights, all_returns) if len(r) > i)
+ pr = sum(
+ w * r[-(min_len - i)]
+ for w, r in zip(weights, all_returns, strict=False)
+ if len(r) > i
+ )
portfolio_returns.append(pr)
if not portfolio_returns:
diff --git a/services/order-executor/tests/test_executor.py b/services/order-executor/tests/test_executor.py
index dd823d7..cda6b72 100644
--- a/services/order-executor/tests/test_executor.py
+++ b/services/order-executor/tests/test_executor.py
@@ -4,11 +4,11 @@ from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock
import pytest
-
-from shared.models import OrderSide, OrderStatus, Signal
from order_executor.executor import OrderExecutor
from order_executor.risk_manager import RiskCheckResult, RiskManager
+from shared.models import OrderSide, OrderStatus, Signal
+
def make_signal(side: OrderSide = OrderSide.BUY, price: str = "100", quantity: str = "1") -> Signal:
return Signal(
diff --git a/services/order-executor/tests/test_risk_manager.py b/services/order-executor/tests/test_risk_manager.py
index 00a9ab4..66e769c 100644
--- a/services/order-executor/tests/test_risk_manager.py
+++ b/services/order-executor/tests/test_risk_manager.py
@@ -2,12 +2,12 @@
from decimal import Decimal
+from order_executor.risk_manager import RiskManager
from shared.models import OrderSide, Position, Signal
-from order_executor.risk_manager import RiskManager
-def make_signal(side: OrderSide, price: str, quantity: str, symbol: str = "BTC/USDT") -> Signal:
+def make_signal(side: OrderSide, price: str, quantity: str, symbol: str = "AAPL") -> Signal:
return Signal(
strategy="test",
symbol=symbol,
@@ -93,7 +93,7 @@ def test_risk_check_rejects_insufficient_balance():
def test_trailing_stop_set_and_trigger():
"""Trailing stop should trigger when price drops below stop level."""
rm = make_risk_manager(trailing_stop_pct="5")
- rm.set_trailing_stop("BTC/USDT", Decimal("100"))
+ rm.set_trailing_stop("AAPL", Decimal("100"))
signal = make_signal(side=OrderSide.BUY, price="94", quantity="0.01")
result = rm.check(signal, balance=Decimal("10000"), positions={}, daily_pnl=Decimal("0"))
@@ -104,10 +104,10 @@ def test_trailing_stop_set_and_trigger():
def test_trailing_stop_updates_highest_price():
"""Trailing stop should track the highest price seen."""
rm = make_risk_manager(trailing_stop_pct="5")
- rm.set_trailing_stop("BTC/USDT", Decimal("100"))
+ rm.set_trailing_stop("AAPL", Decimal("100"))
# Price rises to 120 => stop at 114
- rm.update_price("BTC/USDT", Decimal("120"))
+ rm.update_price("AAPL", Decimal("120"))
# Price at 115 is above stop (114), should be allowed
signal = make_signal(side=OrderSide.BUY, price="115", quantity="0.01")
@@ -124,7 +124,7 @@ def test_trailing_stop_updates_highest_price():
def test_trailing_stop_not_triggered_above_stop():
"""Trailing stop should not trigger when price is above stop level."""
rm = make_risk_manager(trailing_stop_pct="5")
- rm.set_trailing_stop("BTC/USDT", Decimal("100"))
+ rm.set_trailing_stop("AAPL", Decimal("100"))
# Price at 96 is above stop (95), should be allowed
signal = make_signal(side=OrderSide.BUY, price="96", quantity="0.01")
@@ -140,11 +140,11 @@ def test_max_open_positions_check():
rm = make_risk_manager(max_open_positions=2)
positions = {
- "BTC/USDT": make_position("BTC/USDT", "1", "100", "100"),
- "ETH/USDT": make_position("ETH/USDT", "10", "50", "50"),
+ "AAPL": make_position("AAPL", "1", "100", "100"),
+ "MSFT": make_position("MSFT", "10", "50", "50"),
}
- signal = make_signal(side=OrderSide.BUY, price="10", quantity="1", symbol="SOL/USDT")
+ signal = make_signal(side=OrderSide.BUY, price="10", quantity="1", symbol="TSLA")
result = rm.check(signal, balance=Decimal("10000"), positions=positions, daily_pnl=Decimal("0"))
assert result.allowed is False
assert result.reason == "Max open positions reached"
@@ -158,14 +158,14 @@ def test_volatility_calculation():
rm = make_risk_manager(volatility_lookback=5)
# No history yet
- assert rm.get_volatility("BTC/USDT") is None
+ assert rm.get_volatility("AAPL") is None
# Feed prices
prices = [100, 102, 98, 105, 101]
for p in prices:
- rm.update_price("BTC/USDT", Decimal(str(p)))
+ rm.update_price("AAPL", Decimal(str(p)))
- vol = rm.get_volatility("BTC/USDT")
+ vol = rm.get_volatility("AAPL")
assert vol is not None
assert vol > 0
@@ -177,9 +177,9 @@ def test_position_size_with_volatility_scaling():
# Feed volatile prices
prices = [100, 120, 80, 130, 70]
for p in prices:
- rm.update_price("BTC/USDT", Decimal(str(p)))
+ rm.update_price("AAPL", Decimal(str(p)))
- size = rm.calculate_position_size("BTC/USDT", Decimal("10000"))
+ size = rm.calculate_position_size("AAPL", Decimal("10000"))
base = Decimal("10000") * Decimal("0.1")
# High volatility should reduce size below base
@@ -192,9 +192,9 @@ def test_position_size_without_scaling():
prices = [100, 120, 80, 130, 70]
for p in prices:
- rm.update_price("BTC/USDT", Decimal(str(p)))
+ rm.update_price("AAPL", Decimal(str(p)))
- size = rm.calculate_position_size("BTC/USDT", Decimal("10000"))
+ size = rm.calculate_position_size("AAPL", Decimal("10000"))
base = Decimal("10000") * Decimal("0.1")
assert size == base
@@ -211,8 +211,8 @@ def test_portfolio_exposure_check_passes():
max_portfolio_exposure=0.8,
)
positions = {
- "BTCUSDT": Position(
- symbol="BTCUSDT",
+ "AAPL": Position(
+ symbol="AAPL",
quantity=Decimal("0.01"),
avg_entry_price=Decimal("50000"),
current_price=Decimal("50000"),
@@ -230,8 +230,8 @@ def test_portfolio_exposure_check_rejects():
max_portfolio_exposure=0.3,
)
positions = {
- "BTCUSDT": Position(
- symbol="BTCUSDT",
+ "AAPL": Position(
+ symbol="AAPL",
quantity=Decimal("1"),
avg_entry_price=Decimal("50000"),
current_price=Decimal("50000"),
@@ -263,10 +263,10 @@ def test_var_calculation():
daily_loss_limit_pct=Decimal("10"),
)
for i in range(30):
- rm.update_price("BTCUSDT", Decimal(str(100 + (i % 5) - 2)))
+ rm.update_price("AAPL", Decimal(str(100 + (i % 5) - 2)))
positions = {
- "BTCUSDT": Position(
- symbol="BTCUSDT",
+ "AAPL": Position(
+ symbol="AAPL",
quantity=Decimal("1"),
avg_entry_price=Decimal("100"),
current_price=Decimal("100"),
@@ -357,7 +357,7 @@ def test_drawdown_check_rejects_in_check():
rm.update_balance(Decimal("10000"))
signal = Signal(
strategy="test",
- symbol="BTC/USDT",
+ symbol="AAPL",
side=OrderSide.BUY,
price=Decimal("50000"),
quantity=Decimal("0.01"),
diff --git a/services/portfolio-manager/Dockerfile b/services/portfolio-manager/Dockerfile
index b1a7681..0fa3f35 100644
--- a/services/portfolio-manager/Dockerfile
+++ b/services/portfolio-manager/Dockerfile
@@ -1,8 +1,15 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/portfolio-manager/ services/portfolio-manager/
RUN pip install --no-cache-dir ./services/portfolio-manager
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
ENV PYTHONPATH=/app
+USER appuser
CMD ["python", "-m", "portfolio_manager.main"]
diff --git a/services/portfolio-manager/src/portfolio_manager/main.py b/services/portfolio-manager/src/portfolio_manager/main.py
index a6823ae..f885aa8 100644
--- a/services/portfolio-manager/src/portfolio_manager/main.py
+++ b/services/portfolio-manager/src/portfolio_manager/main.py
@@ -2,6 +2,10 @@
import asyncio
+import sqlalchemy.exc
+
+from portfolio_manager.config import PortfolioConfig
+from portfolio_manager.portfolio import PortfolioTracker
from shared.broker import RedisBroker
from shared.db import Database
from shared.events import Event, OrderEvent
@@ -9,9 +13,7 @@ from shared.healthcheck import HealthCheckServer
from shared.logging import setup_logging
from shared.metrics import ServiceMetrics
from shared.notifier import TelegramNotifier
-
-from portfolio_manager.config import PortfolioConfig
-from portfolio_manager.portfolio import PortfolioTracker
+from shared.shutdown import GracefulShutdown
ORDERS_STREAM = "orders"
@@ -51,8 +53,12 @@ async def snapshot_loop(
while True:
try:
await save_snapshot(db, tracker, notifier, log)
+ except (sqlalchemy.exc.OperationalError, ConnectionError, TimeoutError) as exc:
+ log.warning("snapshot_db_error", error=str(exc))
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("snapshot_data_error", error=str(exc))
except Exception as exc:
- log.error("snapshot_failed", error=str(exc))
+ log.error("snapshot_failed", error=str(exc), exc_info=True)
await asyncio.sleep(interval_hours * 3600)
@@ -61,10 +67,10 @@ async def run() -> None:
log = setup_logging("portfolio-manager", config.log_level, config.log_format)
metrics = ServiceMetrics("portfolio_manager")
notifier = TelegramNotifier(
- bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id
+ bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id
)
- broker = RedisBroker(config.redis_url)
+ broker = RedisBroker(config.redis_url.get_secret_value())
tracker = PortfolioTracker()
health = HealthCheckServer(
@@ -76,13 +82,16 @@ async def run() -> None:
await health.start()
metrics.service_up.labels(service="portfolio-manager").set(1)
- db = Database(config.database_url)
+ db = Database(config.database_url.get_secret_value())
await db.connect()
snapshot_task = asyncio.create_task(
snapshot_loop(db, tracker, notifier, config.snapshot_interval_hours, log)
)
+ shutdown = GracefulShutdown()
+ shutdown.install_handlers()
+
GROUP = "portfolio-manager"
CONSUMER = "portfolio-1"
log.info("service_started", stream=ORDERS_STREAM)
@@ -108,12 +117,16 @@ async def run() -> None:
service="portfolio-manager", event_type="order"
).inc()
await broker.ack(ORDERS_STREAM, GROUP, msg_id)
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("pending_parse_error", error=str(exc), msg_id=msg_id)
+ await broker.ack(ORDERS_STREAM, GROUP, msg_id)
+ metrics.errors_total.labels(service="portfolio-manager", error_type="validation").inc()
except Exception as exc:
- log.error("pending_process_failed", error=str(exc), msg_id=msg_id)
+ log.error("pending_process_failed", error=str(exc), msg_id=msg_id, exc_info=True)
metrics.errors_total.labels(service="portfolio-manager", error_type="processing").inc()
try:
- while True:
+ while not shutdown.is_shutting_down:
messages = await broker.read_group(ORDERS_STREAM, GROUP, CONSUMER, count=10, block=1000)
for msg_id, msg in messages:
try:
@@ -134,13 +147,21 @@ async def run() -> None:
service="portfolio-manager", event_type="order"
).inc()
await broker.ack(ORDERS_STREAM, GROUP, msg_id)
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("message_parse_error", error=str(exc), msg_id=msg_id)
+ await broker.ack(ORDERS_STREAM, GROUP, msg_id)
+ metrics.errors_total.labels(
+ service="portfolio-manager", error_type="validation"
+ ).inc()
except Exception as exc:
- log.exception("message_processing_failed", error=str(exc), msg_id=msg_id)
+ log.error(
+ "message_processing_failed", error=str(exc), msg_id=msg_id, exc_info=True
+ )
metrics.errors_total.labels(
service="portfolio-manager", error_type="processing"
).inc()
except Exception as exc:
- log.error("fatal_error", error=str(exc))
+ log.error("fatal_error", error=str(exc), exc_info=True)
await notifier.send_error(str(exc), "portfolio-manager")
raise
finally:
diff --git a/services/portfolio-manager/tests/test_portfolio.py b/services/portfolio-manager/tests/test_portfolio.py
index 768e071..c8a6894 100644
--- a/services/portfolio-manager/tests/test_portfolio.py
+++ b/services/portfolio-manager/tests/test_portfolio.py
@@ -2,15 +2,16 @@
from decimal import Decimal
-from shared.models import Order, OrderSide, OrderStatus, OrderType
from portfolio_manager.portfolio import PortfolioTracker
+from shared.models import Order, OrderSide, OrderStatus, OrderType
+
def make_order(side: OrderSide, price: str, quantity: str) -> Order:
"""Helper to create a filled Order."""
return Order(
signal_id="test-signal",
- symbol="BTC/USDT",
+ symbol="AAPL",
side=side,
type=OrderType.MARKET,
price=Decimal(price),
@@ -24,7 +25,7 @@ def test_portfolio_add_buy_order() -> None:
order = make_order(OrderSide.BUY, "50000", "0.1")
tracker.apply_order(order)
- position = tracker.get_position("BTC/USDT")
+ position = tracker.get_position("AAPL")
assert position is not None
assert position.quantity == Decimal("0.1")
assert position.avg_entry_price == Decimal("50000")
@@ -35,7 +36,7 @@ def test_portfolio_add_multiple_buys() -> None:
tracker.apply_order(make_order(OrderSide.BUY, "50000", "0.1"))
tracker.apply_order(make_order(OrderSide.BUY, "52000", "0.1"))
- position = tracker.get_position("BTC/USDT")
+ position = tracker.get_position("AAPL")
assert position is not None
assert position.quantity == Decimal("0.2")
assert position.avg_entry_price == Decimal("51000")
@@ -46,7 +47,7 @@ def test_portfolio_sell_reduces_position() -> None:
tracker.apply_order(make_order(OrderSide.BUY, "50000", "0.2"))
tracker.apply_order(make_order(OrderSide.SELL, "55000", "0.1"))
- position = tracker.get_position("BTC/USDT")
+ position = tracker.get_position("AAPL")
assert position is not None
assert position.quantity == Decimal("0.1")
assert position.avg_entry_price == Decimal("50000")
@@ -54,7 +55,7 @@ def test_portfolio_sell_reduces_position() -> None:
def test_portfolio_no_position_returns_none() -> None:
tracker = PortfolioTracker()
- position = tracker.get_position("ETH/USDT")
+ position = tracker.get_position("MSFT")
assert position is None
@@ -66,7 +67,7 @@ def test_realized_pnl_on_sell() -> None:
tracker.apply_order(
Order(
signal_id="s1",
- symbol="BTCUSDT",
+ symbol="AAPL",
side=OrderSide.BUY,
type=OrderType.MARKET,
price=Decimal("50000"),
@@ -80,7 +81,7 @@ def test_realized_pnl_on_sell() -> None:
tracker.apply_order(
Order(
signal_id="s2",
- symbol="BTCUSDT",
+ symbol="AAPL",
side=OrderSide.SELL,
type=OrderType.MARKET,
price=Decimal("55000"),
@@ -98,7 +99,7 @@ def test_realized_pnl_on_loss() -> None:
tracker.apply_order(
Order(
signal_id="s1",
- symbol="BTCUSDT",
+ symbol="AAPL",
side=OrderSide.BUY,
type=OrderType.MARKET,
price=Decimal("50000"),
@@ -109,7 +110,7 @@ def test_realized_pnl_on_loss() -> None:
tracker.apply_order(
Order(
signal_id="s2",
- symbol="BTCUSDT",
+ symbol="AAPL",
side=OrderSide.SELL,
type=OrderType.MARKET,
price=Decimal("45000"),
@@ -128,7 +129,7 @@ def test_realized_pnl_accumulates() -> None:
tracker.apply_order(
Order(
signal_id="s1",
- symbol="BTCUSDT",
+ symbol="AAPL",
side=OrderSide.BUY,
type=OrderType.MARKET,
price=Decimal("50000"),
@@ -141,7 +142,7 @@ def test_realized_pnl_accumulates() -> None:
tracker.apply_order(
Order(
signal_id="s2",
- symbol="BTCUSDT",
+ symbol="AAPL",
side=OrderSide.SELL,
type=OrderType.MARKET,
price=Decimal("55000"),
@@ -154,7 +155,7 @@ def test_realized_pnl_accumulates() -> None:
tracker.apply_order(
Order(
signal_id="s3",
- symbol="BTCUSDT",
+ symbol="AAPL",
side=OrderSide.SELL,
type=OrderType.MARKET,
price=Decimal("60000"),
diff --git a/services/portfolio-manager/tests/test_snapshot.py b/services/portfolio-manager/tests/test_snapshot.py
index a464599..f2026e2 100644
--- a/services/portfolio-manager/tests/test_snapshot.py
+++ b/services/portfolio-manager/tests/test_snapshot.py
@@ -1,9 +1,10 @@
"""Tests for save_snapshot in portfolio-manager."""
-import pytest
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock
+import pytest
+
from shared.models import Position
@@ -13,7 +14,7 @@ class TestSaveSnapshot:
from portfolio_manager.main import save_snapshot
pos = Position(
- symbol="BTCUSDT",
+ symbol="AAPL",
quantity=Decimal("0.5"),
avg_entry_price=Decimal("50000"),
current_price=Decimal("52000"),
diff --git a/services/strategy-engine/Dockerfile b/services/strategy-engine/Dockerfile
index de635dc..f1484e9 100644
--- a/services/strategy-engine/Dockerfile
+++ b/services/strategy-engine/Dockerfile
@@ -1,9 +1,16 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/strategy-engine/ services/strategy-engine/
RUN pip install --no-cache-dir ./services/strategy-engine
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
COPY services/strategy-engine/strategies/ /app/strategies/
ENV PYTHONPATH=/app
+USER appuser
CMD ["python", "-m", "strategy_engine.main"]
diff --git a/services/strategy-engine/pyproject.toml b/services/strategy-engine/pyproject.toml
index 4f5b6be..e4bfb12 100644
--- a/services/strategy-engine/pyproject.toml
+++ b/services/strategy-engine/pyproject.toml
@@ -3,11 +3,7 @@ name = "strategy-engine"
version = "0.1.0"
description = "Plugin-based strategy execution engine"
requires-python = ">=3.12"
-dependencies = [
- "pandas>=2.0",
- "numpy>=1.20",
- "trading-shared",
-]
+dependencies = ["pandas>=2.1,<3", "numpy>=1.26,<3", "trading-shared"]
[project.optional-dependencies]
dev = ["pytest>=8.0", "pytest-asyncio>=0.23"]
diff --git a/services/strategy-engine/src/strategy_engine/config.py b/services/strategy-engine/src/strategy_engine/config.py
index e3a49c2..9fd9c49 100644
--- a/services/strategy-engine/src/strategy_engine/config.py
+++ b/services/strategy-engine/src/strategy_engine/config.py
@@ -4,6 +4,6 @@ from shared.config import Settings
class StrategyConfig(Settings):
- symbols: list[str] = ["BTC/USDT"]
+ symbols: list[str] = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"]
timeframes: list[str] = ["1m"]
strategy_params: dict = {}
diff --git a/services/strategy-engine/src/strategy_engine/engine.py b/services/strategy-engine/src/strategy_engine/engine.py
index d401aee..4b2c468 100644
--- a/services/strategy-engine/src/strategy_engine/engine.py
+++ b/services/strategy-engine/src/strategy_engine/engine.py
@@ -2,11 +2,11 @@
import logging
-from shared.broker import RedisBroker
-from shared.events import CandleEvent, SignalEvent, Event
-
from strategies.base import BaseStrategy
+from shared.broker import RedisBroker
+from shared.events import CandleEvent, Event, SignalEvent
+
logger = logging.getLogger(__name__)
@@ -26,7 +26,7 @@ class StrategyEngine:
try:
event = Event.from_dict(raw)
except Exception as exc:
- logger.warning("Failed to parse event: %s – %s", raw, exc)
+ logger.warning("Failed to parse event: %s - %s", raw, exc)
continue
if not isinstance(event, CandleEvent):
diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py
index 30de528..3d73058 100644
--- a/services/strategy-engine/src/strategy_engine/main.py
+++ b/services/strategy-engine/src/strategy_engine/main.py
@@ -1,17 +1,25 @@
"""Strategy Engine Service entry point."""
import asyncio
+import zoneinfo
+from datetime import datetime
from pathlib import Path
+import aiohttp
+
+from shared.alpaca import AlpacaClient
from shared.broker import RedisBroker
+from shared.db import Database
from shared.healthcheck import HealthCheckServer
from shared.logging import setup_logging
from shared.metrics import ServiceMetrics
from shared.notifier import TelegramNotifier
-
+from shared.sentiment_models import MarketSentiment
+from shared.shutdown import GracefulShutdown
from strategy_engine.config import StrategyConfig
from strategy_engine.engine import StrategyEngine
from strategy_engine.plugin_loader import load_strategies
+from strategy_engine.stock_selector import StockSelector
# The strategies directory lives alongside the installed package
STRATEGIES_DIR = Path(__file__).parent.parent.parent.parent / "strategies"
@@ -30,23 +38,74 @@ async def process_symbol(engine: StrategyEngine, stream: str, log) -> None:
last_id = await engine.process_once(stream, last_id)
+async def run_stock_selector(
+ selector: StockSelector,
+ notifier: TelegramNotifier,
+ db: Database,
+ config: StrategyConfig,
+ log,
+) -> None:
+ """Run the stock selector once per day at the configured time."""
+ et = zoneinfo.ZoneInfo("America/New_York")
+
+ while True:
+ now_et = datetime.now(et)
+ target_hour, target_min = map(int, config.selector_final_time.split(":"))
+
+ if now_et.hour == target_hour and now_et.minute == target_min:
+ log.info("stock_selector_running")
+ try:
+ selections = await selector.select()
+ if selections:
+ ms_data = await db.get_latest_market_sentiment()
+ ms = None
+ if ms_data:
+ ms = MarketSentiment(**ms_data)
+ await notifier.send_stock_selection(selections, ms)
+ log.info("stock_selector_complete", picks=[s.symbol for s in selections])
+ else:
+ log.info("stock_selector_no_picks")
+ except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
+ log.warning("stock_selector_network_error", error=str(exc))
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("stock_selector_data_error", error=str(exc))
+ except Exception as exc:
+ log.error("stock_selector_error", error=str(exc), exc_info=True)
+ await asyncio.sleep(120) # Sleep past this minute
+ else:
+ await asyncio.sleep(30)
+
+
async def run() -> None:
config = StrategyConfig()
log = setup_logging("strategy-engine", config.log_level, config.log_format)
metrics = ServiceMetrics("strategy_engine")
notifier = TelegramNotifier(
- bot_token=config.telegram_bot_token,
+ bot_token=config.telegram_bot_token.get_secret_value(),
chat_id=config.telegram_chat_id,
)
- broker = RedisBroker(config.redis_url)
+ broker = RedisBroker(config.redis_url.get_secret_value())
+
+ db = Database(config.database_url.get_secret_value())
+ await db.connect()
+
+ alpaca = AlpacaClient(
+ api_key=config.alpaca_api_key.get_secret_value(),
+ api_secret=config.alpaca_api_secret.get_secret_value(),
+ paper=config.alpaca_paper,
+ )
+
strategies = load_strategies(STRATEGIES_DIR)
for strategy in strategies:
params = config.strategy_params.get(strategy.name, {})
strategy.configure(params)
+ shutdown = GracefulShutdown()
+ shutdown.install_handlers()
+
log.info("loaded_strategies", count=len(strategies), names=[s.name for s in strategies])
engine = StrategyEngine(broker=broker, strategies=strategies)
@@ -67,9 +126,23 @@ async def run() -> None:
task = asyncio.create_task(process_symbol(engine, stream, log))
tasks.append(task)
- await asyncio.gather(*tasks)
+ if config.anthropic_api_key.get_secret_value():
+ selector = StockSelector(
+ db=db,
+ broker=broker,
+ alpaca=alpaca,
+ anthropic_api_key=config.anthropic_api_key.get_secret_value(),
+ anthropic_model=config.anthropic_model,
+ max_picks=config.selector_max_picks,
+ )
+ tasks.append(
+ asyncio.create_task(run_stock_selector(selector, notifier, db, config, log))
+ )
+ log.info("stock_selector_enabled", time=config.selector_final_time)
+
+ await shutdown.wait()
except Exception as exc:
- log.error("fatal_error", error=str(exc))
+ log.error("fatal_error", error=str(exc), exc_info=True)
await notifier.send_error(str(exc), "strategy-engine")
raise
finally:
@@ -78,6 +151,8 @@ async def run() -> None:
metrics.service_up.labels(service="strategy-engine").set(0)
await notifier.close()
await broker.close()
+ await alpaca.close()
+ await db.close()
def main() -> None:
diff --git a/services/strategy-engine/src/strategy_engine/plugin_loader.py b/services/strategy-engine/src/strategy_engine/plugin_loader.py
index 62e4160..57680db 100644
--- a/services/strategy-engine/src/strategy_engine/plugin_loader.py
+++ b/services/strategy-engine/src/strategy_engine/plugin_loader.py
@@ -5,7 +5,6 @@ import sys
from pathlib import Path
import yaml
-
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/src/strategy_engine/stock_selector.py b/services/strategy-engine/src/strategy_engine/stock_selector.py
new file mode 100644
index 0000000..8657b93
--- /dev/null
+++ b/services/strategy-engine/src/strategy_engine/stock_selector.py
@@ -0,0 +1,418 @@
+"""3-stage stock selector engine: sentiment → technical → LLM."""
+
+import asyncio
+import json
+import logging
+import re
+from datetime import UTC, datetime
+
+import aiohttp
+
+from shared.alpaca import AlpacaClient
+from shared.broker import RedisBroker
+from shared.db import Database
+from shared.models import OrderSide
+from shared.sentiment_models import Candidate, MarketSentiment, SelectedStock
+
+logger = logging.getLogger(__name__)
+
+ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages"
+
+
+def _extract_json_array(text: str) -> list[dict] | None:
+ """Extract a JSON array from text that may contain markdown code blocks."""
+ code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL)
+ if code_block:
+ raw = code_block.group(1)
+ else:
+ array_match = re.search(r"\[.*\]", text, re.DOTALL)
+ if array_match:
+ raw = array_match.group(0)
+ else:
+ raw = text.strip()
+
+ try:
+ data = json.loads(raw)
+ if isinstance(data, list):
+ return [item for item in data if isinstance(item, dict)]
+ return None
+ except (json.JSONDecodeError, TypeError):
+ return None
+
+
+def _parse_llm_selections(text: str) -> list[SelectedStock]:
+ """Parse LLM response into SelectedStock list.
+
+ Handles both bare JSON arrays and markdown code blocks.
+ Returns empty list on any parse error.
+ """
+ items = _extract_json_array(text)
+ if items is None:
+ return []
+
+ selections = []
+ for item in items:
+ try:
+ selection = SelectedStock(
+ symbol=item["symbol"],
+ side=OrderSide(item["side"]),
+ conviction=float(item["conviction"]),
+ reason=item.get("reason", ""),
+ key_news=item.get("key_news", []),
+ )
+ selections.append(selection)
+ except (KeyError, ValueError) as e:
+ logger.warning("Skipping invalid selection item: %s", e)
+ return selections
+
+
+class SentimentCandidateSource:
+ """Generates candidates from DB sentiment scores."""
+
+ def __init__(self, db: Database) -> None:
+ self._db = db
+
+ async def get_candidates(self) -> list[Candidate]:
+ rows = await self._db.get_top_symbol_scores(limit=20)
+ candidates = []
+ for row in rows:
+ composite = float(row.get("composite", 0))
+ if composite == 0:
+ continue
+ candidates.append(
+ Candidate(
+ symbol=row["symbol"],
+ source="sentiment",
+ score=composite,
+ reason=f"composite={composite:.2f}, news_count={row.get('news_count', 0)}",
+ )
+ )
+ return candidates
+
+
+class LLMCandidateSource:
+ """Generates candidates by asking Claude to analyze recent news."""
+
+ def __init__(self, db: Database, api_key: str, model: str) -> None:
+ self._db = db
+ self._api_key = api_key
+ self._model = model
+
+ async def get_candidates(self, session: aiohttp.ClientSession | None = None) -> list[Candidate]:
+ news_items = await self._db.get_recent_news(hours=24)
+ if not news_items:
+ return []
+
+ headlines = []
+ for item in news_items[:50]: # cap at 50 to stay within context
+ symbols = item.get("symbols", [])
+ sym_str = ", ".join(symbols) if symbols else "N/A"
+ headlines.append(f"[{sym_str}] {item['headline']}")
+
+ prompt = (
+ "You are a stock analyst. Given recent news headlines, identify the 5-10 most "
+ "actionable US stock tickers. Return ONLY a JSON array with objects having: "
+ "symbol (ticker), direction ('BUY' or 'SELL'), score (0-1), reason (brief).\n\n"
+ "Headlines:\n" + "\n".join(headlines)
+ )
+
+ own_session = session is None
+ if own_session:
+ session = aiohttp.ClientSession()
+
+ try:
+ async with session.post(
+ ANTHROPIC_API_URL,
+ headers={
+ "x-api-key": self._api_key,
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json",
+ },
+ json={
+ "model": self._model,
+ "max_tokens": 1024,
+ "messages": [{"role": "user", "content": prompt}],
+ },
+ ) as resp:
+ if resp.status != 200:
+ body = await resp.text()
+ logger.error("LLM candidate source error %d: %s", resp.status, body)
+ return []
+ data = await resp.json()
+
+ content = data.get("content", [])
+ text = ""
+ for block in content:
+ if isinstance(block, dict) and block.get("type") == "text":
+ text += block.get("text", "")
+
+ return self._parse_candidates(text)
+ except Exception as e:
+ logger.error("LLMCandidateSource error: %s", e)
+ return []
+ finally:
+ if own_session:
+ await session.close()
+
+ def _parse_candidates(self, text: str) -> list[Candidate]:
+ items = _extract_json_array(text)
+ if items is None:
+ return []
+
+ candidates = []
+ for item in items:
+ try:
+ direction_str = item.get("direction", "BUY")
+ direction = OrderSide(direction_str)
+ except ValueError:
+ direction = None
+ candidates.append(
+ Candidate(
+ symbol=item["symbol"],
+ source="llm",
+ direction=direction,
+ score=float(item.get("score", 0.5)),
+ reason=item.get("reason", ""),
+ )
+ )
+ return candidates
+
+
+def _compute_rsi(closes: list[float], period: int = 14) -> float:
+ """Compute RSI for the last data point."""
+ if len(closes) < period + 1:
+ return 50.0 # neutral if insufficient data
+
+ deltas = [closes[i] - closes[i - 1] for i in range(1, len(closes))]
+ gains = [d if d > 0 else 0.0 for d in deltas]
+ losses = [-d if d < 0 else 0.0 for d in deltas]
+
+ avg_gain = sum(gains[:period]) / period
+ avg_loss = sum(losses[:period]) / period
+
+ for i in range(period, len(deltas)):
+ avg_gain = (avg_gain * (period - 1) + gains[i]) / period
+ avg_loss = (avg_loss * (period - 1) + losses[i]) / period
+
+ if avg_loss == 0:
+ return 100.0
+ rs = avg_gain / avg_loss
+ return 100.0 - (100.0 / (1.0 + rs))
+
+
+class StockSelector:
+ """Orchestrates the 3-stage stock selection pipeline."""
+
+ def __init__(
+ self,
+ db: Database,
+ broker: RedisBroker,
+ alpaca: AlpacaClient,
+ anthropic_api_key: str,
+ anthropic_model: str = "claude-sonnet-4-20250514",
+ max_picks: int = 3,
+ ) -> None:
+ self._db = db
+ self._broker = broker
+ self._alpaca = alpaca
+ self._api_key = anthropic_api_key
+ self._model = anthropic_model
+ self._max_picks = max_picks
+ self._http_session: aiohttp.ClientSession | None = None
+ self._session_lock = asyncio.Lock()
+
+ async def _ensure_session(self) -> aiohttp.ClientSession:
+ async with self._session_lock:
+ if self._http_session is None or self._http_session.closed:
+ self._http_session = aiohttp.ClientSession()
+ return self._http_session
+
+ async def close(self) -> None:
+ if self._http_session and not self._http_session.closed:
+ await self._http_session.close()
+
+ async def select(self) -> list[SelectedStock]:
+ """Run the full 3-stage pipeline and return selected stocks."""
+ # Market gate: check sentiment
+ sentiment_data = await self._db.get_latest_market_sentiment()
+ if sentiment_data is None:
+ logger.warning("No market sentiment data; skipping selection")
+ return []
+
+ market_sentiment = MarketSentiment(**sentiment_data)
+ if market_sentiment.market_regime == "risk_off":
+ logger.info("Market is risk_off; skipping stock selection")
+ return []
+
+ # Stage 1: gather candidates from both sources
+ sentiment_source = SentimentCandidateSource(self._db)
+ llm_source = LLMCandidateSource(self._db, self._api_key, self._model)
+
+ session = await self._ensure_session()
+ sentiment_candidates = await sentiment_source.get_candidates()
+ llm_candidates = await llm_source.get_candidates(session=session)
+
+ candidates = self._merge_candidates(sentiment_candidates, llm_candidates)
+ if not candidates:
+ logger.info("No candidates found")
+ return []
+
+ # Stage 2: technical filter
+ filtered = await self._technical_filter(candidates)
+ if not filtered:
+ logger.info("All candidates filtered out by technical criteria")
+ return []
+
+ # Stage 3: LLM final selection
+ selections = await self._llm_final_select(filtered, market_sentiment)
+
+ # Persist and publish
+ today = datetime.now(UTC).date()
+ sentiment_snapshot = {
+ "fear_greed": market_sentiment.fear_greed,
+ "market_regime": market_sentiment.market_regime,
+ "vix": market_sentiment.vix,
+ }
+ for stock in selections:
+ try:
+ await self._db.insert_stock_selection(
+ trade_date=today,
+ symbol=stock.symbol,
+ side=stock.side.value,
+ conviction=stock.conviction,
+ reason=stock.reason,
+ key_news=stock.key_news,
+ sentiment_snapshot=sentiment_snapshot,
+ )
+ except Exception as e:
+ logger.error("Failed to persist selection for %s: %s", stock.symbol, e)
+
+ try:
+ await self._broker.publish(
+ "selected_stocks",
+ {
+ "symbol": stock.symbol,
+ "side": stock.side.value,
+ "conviction": stock.conviction,
+ "reason": stock.reason,
+ "key_news": stock.key_news,
+ "trade_date": str(today),
+ },
+ )
+ except Exception as e:
+ logger.error("Failed to publish selection for %s: %s", stock.symbol, e)
+
+ return selections
+
+ def _merge_candidates(
+ self, sentiment: list[Candidate], llm: list[Candidate]
+ ) -> list[Candidate]:
+ """Deduplicate candidates by symbol, keeping the higher score."""
+ by_symbol: dict[str, Candidate] = {}
+ for c in sentiment + llm:
+ existing = by_symbol.get(c.symbol)
+ if existing is None or c.score > existing.score:
+ by_symbol[c.symbol] = c
+ return sorted(by_symbol.values(), key=lambda c: c.score, reverse=True)
+
+ async def _technical_filter(self, candidates: list[Candidate]) -> list[Candidate]:
+ """Filter candidates using RSI, EMA20, and volume criteria."""
+ passed = []
+ for candidate in candidates:
+ try:
+ bars = await self._alpaca.get_bars(candidate.symbol, timeframe="1Day", limit=60)
+ if len(bars) < 21:
+ logger.debug("Insufficient bars for %s", candidate.symbol)
+ continue
+
+ closes = [float(b["c"]) for b in bars]
+ volumes = [float(b["v"]) for b in bars]
+
+ rsi = _compute_rsi(closes)
+ if not (30 <= rsi <= 70):
+ logger.debug("%s RSI=%.1f outside 30-70", candidate.symbol, rsi)
+ continue
+
+ ema20 = sum(closes[-20:]) / 20 # simple approximation
+ current_price = closes[-1]
+ if current_price <= ema20:
+ logger.debug(
+ "%s price %.2f <= EMA20 %.2f", candidate.symbol, current_price, ema20
+ )
+ continue
+
+ avg_volume = sum(volumes[:-1]) / max(len(volumes) - 1, 1)
+ current_volume = volumes[-1]
+ if current_volume <= 0.5 * avg_volume:
+ logger.debug(
+ "%s volume %.0f <= 50%% avg %.0f",
+ candidate.symbol,
+ current_volume,
+ avg_volume,
+ )
+ continue
+
+ passed.append(candidate)
+ except Exception as e:
+ logger.warning("Technical filter error for %s: %s", candidate.symbol, e)
+
+ return passed
+
+ async def _llm_final_select(
+ self, candidates: list[Candidate], market_sentiment: MarketSentiment
+ ) -> list[SelectedStock]:
+ """Ask Claude to pick 2-3 stocks with rationale."""
+ candidate_lines = [
+ f"- {c.symbol} (source={c.source}, score={c.score:.2f}, reason={c.reason})"
+ for c in candidates
+ ]
+ market_context = (
+ f"Fear/Greed: {market_sentiment.fear_greed} ({market_sentiment.fear_greed_label}), "
+ f"VIX: {market_sentiment.vix}, "
+ f"Fed stance: {market_sentiment.fed_stance}, "
+ f"Regime: {market_sentiment.market_regime}"
+ )
+
+ prompt = (
+ f"You are a portfolio manager. Select 2-3 stocks for today's session.\n\n"
+ f"Market context: {market_context}\n\n"
+ f"Candidates (already passed technical filters):\n"
+ + "\n".join(candidate_lines)
+ + "\n\n"
+ "Return ONLY a JSON array with objects having:\n"
+ " symbol, side ('BUY' or 'SELL'), conviction (0-1), reason (1-2 sentences), "
+ "key_news (list of 1-3 relevant headlines or facts)\n"
+ f"Select at most {self._max_picks} stocks."
+ )
+
+ try:
+ session = await self._ensure_session()
+ async with session.post(
+ ANTHROPIC_API_URL,
+ headers={
+ "x-api-key": self._api_key,
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json",
+ },
+ json={
+ "model": self._model,
+ "max_tokens": 1024,
+ "messages": [{"role": "user", "content": prompt}],
+ },
+ ) as resp:
+ if resp.status != 200:
+ body = await resp.text()
+ logger.error("LLM final select error %d: %s", resp.status, body)
+ return []
+ data = await resp.json()
+
+ content = data.get("content", [])
+ text = ""
+ for block in content:
+ if isinstance(block, dict) and block.get("type") == "text":
+ text += block.get("text", "")
+
+ return _parse_llm_selections(text)[: self._max_picks]
+ except Exception as e:
+ logger.error("LLM final select error: %s", e)
+ return []
diff --git a/services/strategy-engine/strategies/base.py b/services/strategy-engine/strategies/base.py
index d5be675..1d9d289 100644
--- a/services/strategy-engine/strategies/base.py
+++ b/services/strategy-engine/strategies/base.py
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from collections import deque
from decimal import Decimal
-from typing import Optional
import pandas as pd
@@ -102,7 +101,7 @@ class BaseStrategy(ABC):
def _calculate_atr_stops(
self, entry_price: Decimal, side: str
- ) -> tuple[Optional[Decimal], Optional[Decimal]]:
+ ) -> tuple[Decimal | None, Decimal | None]:
"""Calculate ATR-based stop-loss and take-profit.
Returns (stop_loss, take_profit) as Decimal or (None, None) if not enough data.
@@ -131,7 +130,7 @@ class BaseStrategy(ABC):
return sl, tp
- def _apply_filters(self, signal: Signal) -> Optional[Signal]:
+ def _apply_filters(self, signal: Signal) -> Signal | None:
"""Apply all filters to a signal. Returns signal with SL/TP or None if filtered out."""
if signal is None:
return None
diff --git a/services/strategy-engine/strategies/bollinger_strategy.py b/services/strategy-engine/strategies/bollinger_strategy.py
index ebe7967..02ff09a 100644
--- a/services/strategy-engine/strategies/bollinger_strategy.py
+++ b/services/strategy-engine/strategies/bollinger_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/combined_strategy.py b/services/strategy-engine/strategies/combined_strategy.py
index ba92485..f562918 100644
--- a/services/strategy-engine/strategies/combined_strategy.py
+++ b/services/strategy-engine/strategies/combined_strategy.py
@@ -2,7 +2,7 @@
from decimal import Decimal
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/ema_crossover_strategy.py b/services/strategy-engine/strategies/ema_crossover_strategy.py
index 68d0ba3..9c181f3 100644
--- a/services/strategy-engine/strategies/ema_crossover_strategy.py
+++ b/services/strategy-engine/strategies/ema_crossover_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/grid_strategy.py b/services/strategy-engine/strategies/grid_strategy.py
index 283bfe5..491252e 100644
--- a/services/strategy-engine/strategies/grid_strategy.py
+++ b/services/strategy-engine/strategies/grid_strategy.py
@@ -1,9 +1,8 @@
from decimal import Decimal
-from typing import Optional
import numpy as np
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
@@ -17,7 +16,7 @@ class GridStrategy(BaseStrategy):
self._grid_count: int = 5
self._quantity: Decimal = Decimal("0.01")
self._grid_levels: list[float] = []
- self._last_zone: Optional[int] = None
+ self._last_zone: int | None = None
self._exit_threshold_pct: float = 5.0
self._out_of_range: bool = False
self._in_position: bool = False # Track if we have any grid positions
diff --git a/services/strategy-engine/strategies/indicators/__init__.py b/services/strategy-engine/strategies/indicators/__init__.py
index 3c713e6..01637b7 100644
--- a/services/strategy-engine/strategies/indicators/__init__.py
+++ b/services/strategy-engine/strategies/indicators/__init__.py
@@ -1,21 +1,21 @@
"""Reusable technical indicator functions."""
-from strategies.indicators.trend import ema, sma, macd, adx
-from strategies.indicators.volatility import atr, bollinger_bands, keltner_channels
from strategies.indicators.momentum import rsi, stochastic
-from strategies.indicators.volume import volume_sma, volume_ratio, obv
+from strategies.indicators.trend import adx, ema, macd, sma
+from strategies.indicators.volatility import atr, bollinger_bands, keltner_channels
+from strategies.indicators.volume import obv, volume_ratio, volume_sma
__all__ = [
- "ema",
- "sma",
- "macd",
"adx",
"atr",
"bollinger_bands",
+ "ema",
"keltner_channels",
+ "macd",
+ "obv",
"rsi",
+ "sma",
"stochastic",
- "volume_sma",
"volume_ratio",
- "obv",
+ "volume_sma",
]
diff --git a/services/strategy-engine/strategies/indicators/momentum.py b/services/strategy-engine/strategies/indicators/momentum.py
index c479452..a82210b 100644
--- a/services/strategy-engine/strategies/indicators/momentum.py
+++ b/services/strategy-engine/strategies/indicators/momentum.py
@@ -1,7 +1,7 @@
"""Momentum indicators: RSI, Stochastic."""
-import pandas as pd
import numpy as np
+import pandas as pd
def rsi(closes: pd.Series, period: int = 14) -> pd.Series:
diff --git a/services/strategy-engine/strategies/indicators/trend.py b/services/strategy-engine/strategies/indicators/trend.py
index c94a071..1085199 100644
--- a/services/strategy-engine/strategies/indicators/trend.py
+++ b/services/strategy-engine/strategies/indicators/trend.py
@@ -1,7 +1,7 @@
"""Trend indicators: EMA, SMA, MACD, ADX."""
-import pandas as pd
import numpy as np
+import pandas as pd
def sma(series: pd.Series, period: int) -> pd.Series:
diff --git a/services/strategy-engine/strategies/indicators/volatility.py b/services/strategy-engine/strategies/indicators/volatility.py
index c16143e..da82f26 100644
--- a/services/strategy-engine/strategies/indicators/volatility.py
+++ b/services/strategy-engine/strategies/indicators/volatility.py
@@ -1,7 +1,7 @@
"""Volatility indicators: ATR, Bollinger Bands, Keltner Channels."""
-import pandas as pd
import numpy as np
+import pandas as pd
def atr(
diff --git a/services/strategy-engine/strategies/indicators/volume.py b/services/strategy-engine/strategies/indicators/volume.py
index 502f1ce..d7c6471 100644
--- a/services/strategy-engine/strategies/indicators/volume.py
+++ b/services/strategy-engine/strategies/indicators/volume.py
@@ -1,7 +1,7 @@
"""Volume indicators: Volume SMA, Volume Ratio, OBV."""
-import pandas as pd
import numpy as np
+import pandas as pd
def volume_sma(volumes: pd.Series, period: int = 20) -> pd.Series:
diff --git a/services/strategy-engine/strategies/macd_strategy.py b/services/strategy-engine/strategies/macd_strategy.py
index 356a42b..b5aea07 100644
--- a/services/strategy-engine/strategies/macd_strategy.py
+++ b/services/strategy-engine/strategies/macd_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/moc_strategy.py b/services/strategy-engine/strategies/moc_strategy.py
index 7eaa59e..cbc8440 100644
--- a/services/strategy-engine/strategies/moc_strategy.py
+++ b/services/strategy-engine/strategies/moc_strategy.py
@@ -8,12 +8,12 @@ Rules:
"""
from collections import deque
-from decimal import Decimal
from datetime import datetime
+from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/rsi_strategy.py b/services/strategy-engine/strategies/rsi_strategy.py
index 0646d8c..2df080d 100644
--- a/services/strategy-engine/strategies/rsi_strategy.py
+++ b/services/strategy-engine/strategies/rsi_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/volume_profile_strategy.py b/services/strategy-engine/strategies/volume_profile_strategy.py
index ef2ae14..67b5c23 100644
--- a/services/strategy-engine/strategies/volume_profile_strategy.py
+++ b/services/strategy-engine/strategies/volume_profile_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import numpy as np
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
@@ -137,7 +137,7 @@ class VolumeProfileStrategy(BaseStrategy):
if result is None:
return None
- poc, va_low, va_high, hvn_levels, lvn_levels = result
+ poc, va_low, va_high, hvn_levels, _lvn_levels = result
if close < va_low:
self._was_below_va = True
diff --git a/services/strategy-engine/strategies/vwap_strategy.py b/services/strategy-engine/strategies/vwap_strategy.py
index d64950e..4ee4952 100644
--- a/services/strategy-engine/strategies/vwap_strategy.py
+++ b/services/strategy-engine/strategies/vwap_strategy.py
@@ -1,7 +1,7 @@
from collections import deque
from decimal import Decimal
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
@@ -107,7 +107,7 @@ class VwapStrategy(BaseStrategy):
# Standard deviation of (TP - VWAP) for bands
std_dev = 0.0
if len(self._tp_values) >= 2:
- diffs = [tp - v for tp, v in zip(self._tp_values, self._vwap_values)]
+ diffs = [tp - v for tp, v in zip(self._tp_values, self._vwap_values, strict=True)]
mean_diff = sum(diffs) / len(diffs)
variance = sum((d - mean_diff) ** 2 for d in diffs) / len(diffs)
std_dev = variance**0.5
diff --git a/services/strategy-engine/tests/conftest.py b/services/strategy-engine/tests/conftest.py
index eb31b23..2b909ef 100644
--- a/services/strategy-engine/tests/conftest.py
+++ b/services/strategy-engine/tests/conftest.py
@@ -7,3 +7,8 @@ from pathlib import Path
STRATEGIES_DIR = Path(__file__).parent.parent / "strategies"
if str(STRATEGIES_DIR) not in sys.path:
sys.path.insert(0, str(STRATEGIES_DIR.parent))
+
+# Ensure the worktree's strategy_engine src is preferred over any installed version
+WORKTREE_SRC = Path(__file__).parent.parent / "src"
+if str(WORKTREE_SRC) not in sys.path:
+ sys.path.insert(0, str(WORKTREE_SRC))
diff --git a/services/strategy-engine/tests/test_base_filters.py b/services/strategy-engine/tests/test_base_filters.py
index ae9ca05..66adec7 100644
--- a/services/strategy-engine/tests/test_base_filters.py
+++ b/services/strategy-engine/tests/test_base_filters.py
@@ -5,12 +5,13 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
-from shared.models import Candle, Signal, OrderSide
from strategies.base import BaseStrategy
+from shared.models import Candle, OrderSide, Signal
+
class DummyStrategy(BaseStrategy):
name = "dummy"
@@ -45,7 +46,7 @@ def _candle(price=100.0, volume=10.0, high=None, low=None):
return Candle(
symbol="AAPL",
timeframe="1h",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal(str(price)),
high=Decimal(str(h)),
low=Decimal(str(lo)),
diff --git a/services/strategy-engine/tests/test_bollinger_strategy.py b/services/strategy-engine/tests/test_bollinger_strategy.py
index 8261377..70ec66e 100644
--- a/services/strategy-engine/tests/test_bollinger_strategy.py
+++ b/services/strategy-engine/tests/test_bollinger_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the Bollinger Bands strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.bollinger_strategy import BollingerStrategy
from shared.models import Candle, OrderSide
-from strategies.bollinger_strategy import BollingerStrategy
def make_candle(close: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_combined_strategy.py b/services/strategy-engine/tests/test_combined_strategy.py
index 8a4dc74..6a15250 100644
--- a/services/strategy-engine/tests/test_combined_strategy.py
+++ b/services/strategy-engine/tests/test_combined_strategy.py
@@ -5,13 +5,14 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
-import pytest
-from shared.models import Candle, Signal, OrderSide
-from strategies.combined_strategy import CombinedStrategy
+import pytest
from strategies.base import BaseStrategy
+from strategies.combined_strategy import CombinedStrategy
+
+from shared.models import Candle, OrderSide, Signal
class AlwaysBuyStrategy(BaseStrategy):
@@ -74,7 +75,7 @@ def _candle(price=100.0):
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal(str(price)),
high=Decimal(str(price + 10)),
low=Decimal(str(price - 10)),
diff --git a/services/strategy-engine/tests/test_ema_crossover_strategy.py b/services/strategy-engine/tests/test_ema_crossover_strategy.py
index 7028eb0..af2b587 100644
--- a/services/strategy-engine/tests/test_ema_crossover_strategy.py
+++ b/services/strategy-engine/tests/test_ema_crossover_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the EMA Crossover strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.ema_crossover_strategy import EmaCrossoverStrategy
from shared.models import Candle, OrderSide
-from strategies.ema_crossover_strategy import EmaCrossoverStrategy
def make_candle(close: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_engine.py b/services/strategy-engine/tests/test_engine.py
index 2623027..fa888b5 100644
--- a/services/strategy-engine/tests/test_engine.py
+++ b/services/strategy-engine/tests/test_engine.py
@@ -1,21 +1,21 @@
"""Tests for the StrategyEngine."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock
import pytest
+from strategy_engine.engine import StrategyEngine
-from shared.models import Candle, Signal, OrderSide
from shared.events import CandleEvent
-from strategy_engine.engine import StrategyEngine
+from shared.models import Candle, OrderSide, Signal
def make_candle_event() -> dict:
candle = Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("50100"),
low=Decimal("49900"),
diff --git a/services/strategy-engine/tests/test_grid_strategy.py b/services/strategy-engine/tests/test_grid_strategy.py
index 878b900..f697012 100644
--- a/services/strategy-engine/tests/test_grid_strategy.py
+++ b/services/strategy-engine/tests/test_grid_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the Grid strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.grid_strategy import GridStrategy
from shared.models import Candle, OrderSide
-from strategies.grid_strategy import GridStrategy
def make_candle(close: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_indicators.py b/services/strategy-engine/tests/test_indicators.py
index 481569b..3147fc4 100644
--- a/services/strategy-engine/tests/test_indicators.py
+++ b/services/strategy-engine/tests/test_indicators.py
@@ -5,14 +5,13 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
-import pandas as pd
import numpy as np
+import pandas as pd
import pytest
-
-from strategies.indicators.trend import sma, ema, macd, adx
-from strategies.indicators.volatility import atr, bollinger_bands
from strategies.indicators.momentum import rsi, stochastic
-from strategies.indicators.volume import volume_sma, volume_ratio, obv
+from strategies.indicators.trend import adx, ema, macd, sma
+from strategies.indicators.volatility import atr, bollinger_bands
+from strategies.indicators.volume import obv, volume_ratio, volume_sma
class TestTrend:
diff --git a/services/strategy-engine/tests/test_macd_strategy.py b/services/strategy-engine/tests/test_macd_strategy.py
index 556fd4c..7fac16f 100644
--- a/services/strategy-engine/tests/test_macd_strategy.py
+++ b/services/strategy-engine/tests/test_macd_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the MACD strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.macd_strategy import MacdStrategy
from shared.models import Candle, OrderSide
-from strategies.macd_strategy import MacdStrategy
def _candle(price: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(price)),
high=Decimal(str(price)),
low=Decimal(str(price)),
diff --git a/services/strategy-engine/tests/test_moc_strategy.py b/services/strategy-engine/tests/test_moc_strategy.py
index 1928a28..076e846 100644
--- a/services/strategy-engine/tests/test_moc_strategy.py
+++ b/services/strategy-engine/tests/test_moc_strategy.py
@@ -5,19 +5,20 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
-from shared.models import Candle, OrderSide
from strategies.moc_strategy import MocStrategy
+from shared.models import Candle, OrderSide
+
def _candle(price, hour=20, minute=0, volume=100.0, day=1, open_price=None):
op = open_price if open_price is not None else price - 1 # Default: bullish
return Candle(
symbol="AAPL",
timeframe="5Min",
- open_time=datetime(2025, 1, day, hour, minute, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, day, hour, minute, tzinfo=UTC),
open=Decimal(str(op)),
high=Decimal(str(price + 1)),
low=Decimal(str(min(op, price) - 1)),
diff --git a/services/strategy-engine/tests/test_multi_symbol.py b/services/strategy-engine/tests/test_multi_symbol.py
index 671a9d3..922bfc2 100644
--- a/services/strategy-engine/tests/test_multi_symbol.py
+++ b/services/strategy-engine/tests/test_multi_symbol.py
@@ -9,11 +9,13 @@ import pytest
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+from datetime import UTC, datetime
+from decimal import Decimal
+
from strategy_engine.engine import StrategyEngine
+
from shared.events import CandleEvent
from shared.models import Candle
-from decimal import Decimal
-from datetime import datetime, timezone
@pytest.mark.asyncio
@@ -24,7 +26,7 @@ async def test_engine_processes_multiple_streams():
candle_btc = Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("51000"),
low=Decimal("49000"),
@@ -34,7 +36,7 @@ async def test_engine_processes_multiple_streams():
candle_eth = Candle(
symbol="MSFT",
timeframe="1m",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal("3000"),
high=Decimal("3100"),
low=Decimal("2900"),
diff --git a/services/strategy-engine/tests/test_plugin_loader.py b/services/strategy-engine/tests/test_plugin_loader.py
index 5191fc3..7bd450f 100644
--- a/services/strategy-engine/tests/test_plugin_loader.py
+++ b/services/strategy-engine/tests/test_plugin_loader.py
@@ -2,10 +2,8 @@
from pathlib import Path
-
from strategy_engine.plugin_loader import load_strategies
-
STRATEGIES_DIR = Path(__file__).parent.parent / "strategies"
diff --git a/services/strategy-engine/tests/test_rsi_strategy.py b/services/strategy-engine/tests/test_rsi_strategy.py
index 6d31fd5..6c74f0b 100644
--- a/services/strategy-engine/tests/test_rsi_strategy.py
+++ b/services/strategy-engine/tests/test_rsi_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the RSI strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.rsi_strategy import RsiStrategy
from shared.models import Candle, OrderSide
-from strategies.rsi_strategy import RsiStrategy
def make_candle(close: float, idx: int = 0) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_stock_selector.py b/services/strategy-engine/tests/test_stock_selector.py
new file mode 100644
index 0000000..76b8541
--- /dev/null
+++ b/services/strategy-engine/tests/test_stock_selector.py
@@ -0,0 +1,111 @@
+"""Tests for stock selector engine."""
+
+from datetime import UTC, datetime
+from unittest.mock import AsyncMock, MagicMock
+
+from strategy_engine.stock_selector import (
+ SentimentCandidateSource,
+ StockSelector,
+ _extract_json_array,
+ _parse_llm_selections,
+)
+
+
+async def test_sentiment_candidate_source():
+ mock_db = MagicMock()
+ mock_db.get_top_symbol_scores = AsyncMock(
+ return_value=[
+ {"symbol": "AAPL", "composite": 0.8, "news_count": 5},
+ {"symbol": "NVDA", "composite": 0.6, "news_count": 3},
+ ]
+ )
+
+ source = SentimentCandidateSource(mock_db)
+ candidates = await source.get_candidates()
+
+ assert len(candidates) == 2
+ assert candidates[0].symbol == "AAPL"
+ assert candidates[0].source == "sentiment"
+
+
+def test_parse_llm_selections_valid():
+ llm_response = """
+ [
+ {"symbol": "NVDA", "side": "BUY", "conviction": 0.85, "reason": "AI demand", "key_news": ["NVDA beats earnings"]},
+ {"symbol": "XOM", "side": "BUY", "conviction": 0.72, "reason": "Oil surge", "key_news": ["Oil prices up"]}
+ ]
+ """
+ selections = _parse_llm_selections(llm_response)
+ assert len(selections) == 2
+ assert selections[0].symbol == "NVDA"
+ assert selections[0].conviction == 0.85
+
+
+def test_parse_llm_selections_invalid():
+ selections = _parse_llm_selections("not json")
+ assert selections == []
+
+
+def test_parse_llm_selections_with_markdown():
+ llm_response = """
+ Here are my picks:
+ ```json
+ [
+ {"symbol": "TSLA", "side": "BUY", "conviction": 0.7, "reason": "Momentum", "key_news": ["Tesla rally"]}
+ ]
+ ```
+ """
+ selections = _parse_llm_selections(llm_response)
+ assert len(selections) == 1
+ assert selections[0].symbol == "TSLA"
+
+
+def test_extract_json_array_from_markdown():
+ text = '```json\n[{"symbol": "AAPL", "score": 0.9}]\n```'
+ result = _extract_json_array(text)
+ assert result == [{"symbol": "AAPL", "score": 0.9}]
+
+
+def test_extract_json_array_bare():
+ text = '[{"symbol": "TSLA"}]'
+ result = _extract_json_array(text)
+ assert result == [{"symbol": "TSLA"}]
+
+
+def test_extract_json_array_invalid():
+ assert _extract_json_array("not json") is None
+
+
+def test_extract_json_array_filters_non_dicts():
+ text = '[{"symbol": "AAPL"}, "bad", 42]'
+ result = _extract_json_array(text)
+ assert result == [{"symbol": "AAPL"}]
+
+
+async def test_selector_close():
+ selector = StockSelector(
+ db=MagicMock(), broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test"
+ )
+ # No session yet - close should be safe
+ await selector.close()
+ assert selector._http_session is None
+
+
+async def test_selector_blocks_on_risk_off():
+ mock_db = MagicMock()
+ mock_db.get_latest_market_sentiment = AsyncMock(
+ return_value={
+ "fear_greed": 15,
+ "fear_greed_label": "Extreme Fear",
+ "vix": 35.0,
+ "fed_stance": "neutral",
+ "market_regime": "risk_off",
+ "updated_at": datetime.now(UTC),
+ }
+ )
+
+ selector = StockSelector(
+ db=mock_db, broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test"
+ )
+ result = await selector.select()
+ assert result == []
diff --git a/services/strategy-engine/tests/test_strategy_validation.py b/services/strategy-engine/tests/test_strategy_validation.py
index debab1f..0d9607a 100644
--- a/services/strategy-engine/tests/test_strategy_validation.py
+++ b/services/strategy-engine/tests/test_strategy_validation.py
@@ -1,13 +1,11 @@
import pytest
-
-from strategies.rsi_strategy import RsiStrategy
-from strategies.macd_strategy import MacdStrategy
from strategies.bollinger_strategy import BollingerStrategy
from strategies.ema_crossover_strategy import EmaCrossoverStrategy
from strategies.grid_strategy import GridStrategy
-from strategies.vwap_strategy import VwapStrategy
+from strategies.macd_strategy import MacdStrategy
+from strategies.rsi_strategy import RsiStrategy
from strategies.volume_profile_strategy import VolumeProfileStrategy
-
+from strategies.vwap_strategy import VwapStrategy
# ── RSI ──────────────────────────────────────────────────────────────────
diff --git a/services/strategy-engine/tests/test_volume_profile_strategy.py b/services/strategy-engine/tests/test_volume_profile_strategy.py
index 65ee2e8..f47898c 100644
--- a/services/strategy-engine/tests/test_volume_profile_strategy.py
+++ b/services/strategy-engine/tests/test_volume_profile_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the Volume Profile strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.volume_profile_strategy import VolumeProfileStrategy
from shared.models import Candle, OrderSide
-from strategies.volume_profile_strategy import VolumeProfileStrategy
def make_candle(close: float, volume: float = 1.0) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
@@ -134,13 +134,10 @@ def test_volume_profile_hvn_detection():
# Create a profile with very high volume at price ~100 and low volume elsewhere
# Prices range from 90 to 110, heavy volume concentrated at 100
- candles_data = []
# Low volume at extremes
- for p in [90, 91, 92, 109, 110]:
- candles_data.append((p, 1.0))
+ candles_data = [(p, 1.0) for p in [90, 91, 92, 109, 110]]
# Very high volume around 100
- for _ in range(15):
- candles_data.append((100, 100.0))
+ candles_data.extend((100, 100.0) for _ in range(15))
for price, vol in candles_data:
strategy.on_candle(make_candle(price, vol))
@@ -148,7 +145,7 @@ def test_volume_profile_hvn_detection():
# Access the internal method to verify HVN detection
result = strategy._compute_value_area()
assert result is not None
- poc, va_low, va_high, hvn_levels, lvn_levels = result
+ _poc, _va_low, _va_high, hvn_levels, _lvn_levels = result
# The bin containing price ~100 should have very high volume -> HVN
assert len(hvn_levels) > 0
diff --git a/services/strategy-engine/tests/test_vwap_strategy.py b/services/strategy-engine/tests/test_vwap_strategy.py
index 2c34b01..078d0cf 100644
--- a/services/strategy-engine/tests/test_vwap_strategy.py
+++ b/services/strategy-engine/tests/test_vwap_strategy.py
@@ -1,11 +1,11 @@
"""Tests for the VWAP strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.vwap_strategy import VwapStrategy
from shared.models import Candle, OrderSide
-from strategies.vwap_strategy import VwapStrategy
def make_candle(
@@ -20,7 +20,7 @@ def make_candle(
if low is None:
low = close
if open_time is None:
- open_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
+ open_time = datetime(2024, 1, 1, tzinfo=UTC)
return Candle(
symbol="AAPL",
timeframe="1m",
@@ -111,11 +111,11 @@ def test_vwap_daily_reset():
"""Candles from two different dates cause VWAP to reset."""
strategy = _configured_strategy()
- day1 = datetime(2024, 1, 1, tzinfo=timezone.utc)
- day2 = datetime(2024, 1, 2, tzinfo=timezone.utc)
+ day1 = datetime(2024, 1, 1, tzinfo=UTC)
+ day2 = datetime(2024, 1, 2, tzinfo=UTC)
# Feed 35 candles on day 1 to build VWAP state
- for i in range(35):
+ for _i in range(35):
strategy.on_candle(make_candle(100.0, high=101.0, low=99.0, open_time=day1))
# Verify state is built up