summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--services/strategy-engine/src/strategy_engine/stock_selector.py403
-rw-r--r--services/strategy-engine/tests/conftest.py5
-rw-r--r--services/strategy-engine/tests/test_stock_selector.py78
3 files changed, 486 insertions, 0 deletions
diff --git a/services/strategy-engine/src/strategy_engine/stock_selector.py b/services/strategy-engine/src/strategy_engine/stock_selector.py
new file mode 100644
index 0000000..e1f2fe7
--- /dev/null
+++ b/services/strategy-engine/src/strategy_engine/stock_selector.py
@@ -0,0 +1,403 @@
+"""3-stage stock selector engine: sentiment → technical → LLM."""
+
+import json
+import logging
+import re
+from datetime import datetime, timezone
+from decimal import Decimal
+from typing import Optional
+
+import aiohttp
+import pandas as pd
+
+from shared.alpaca import AlpacaClient
+from shared.broker import RedisBroker
+from shared.db import Database
+from shared.models import OrderSide
+from shared.sentiment_models import Candidate, MarketSentiment, SelectedStock, SymbolScore
+
+logger = logging.getLogger(__name__)
+
+ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages"
+
+
+def _parse_llm_selections(text: str) -> list[SelectedStock]:
+ """Parse LLM response into SelectedStock list.
+
+ Handles both bare JSON arrays and markdown code blocks.
+ Returns empty list on any parse error.
+ """
+ # Try to extract JSON from markdown code block first
+ code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL)
+ if code_block:
+ raw = code_block.group(1)
+ else:
+ # Try to find a bare JSON array
+ array_match = re.search(r"\[.*\]", text, re.DOTALL)
+ if array_match:
+ raw = array_match.group(0)
+ else:
+ raw = text.strip()
+
+ try:
+ data = json.loads(raw)
+ if not isinstance(data, list):
+ return []
+ selections = []
+ for item in data:
+ if not isinstance(item, dict):
+ continue
+ try:
+ selection = SelectedStock(
+ symbol=item["symbol"],
+ side=OrderSide(item["side"]),
+ conviction=float(item["conviction"]),
+ reason=item.get("reason", ""),
+ key_news=item.get("key_news", []),
+ )
+ selections.append(selection)
+ except (KeyError, ValueError) as e:
+ logger.warning("Skipping invalid selection item: %s", e)
+ return selections
+ except (json.JSONDecodeError, TypeError):
+ return []
+
+
+class SentimentCandidateSource:
+ """Generates candidates from DB sentiment scores."""
+
+ def __init__(self, db: Database) -> None:
+ self._db = db
+
+ async def get_candidates(self) -> list[Candidate]:
+ rows = await self._db.get_top_symbol_scores(limit=20)
+ candidates = []
+ for row in rows:
+ composite = float(row.get("composite", 0))
+ if composite == 0:
+ continue
+ candidates.append(
+ Candidate(
+ symbol=row["symbol"],
+ source="sentiment",
+ score=composite,
+ reason=f"composite={composite:.2f}, news_count={row.get('news_count', 0)}",
+ )
+ )
+ return candidates
+
+
+class LLMCandidateSource:
+ """Generates candidates by asking Claude to analyze recent news."""
+
+ def __init__(self, db: Database, api_key: str, model: str) -> None:
+ self._db = db
+ self._api_key = api_key
+ self._model = model
+
+ async def get_candidates(self) -> list[Candidate]:
+ news_items = await self._db.get_recent_news(hours=24)
+ if not news_items:
+ return []
+
+ headlines = []
+ for item in news_items[:50]: # cap at 50 to stay within context
+ symbols = item.get("symbols", [])
+ sym_str = ", ".join(symbols) if symbols else "N/A"
+ headlines.append(f"[{sym_str}] {item['headline']}")
+
+ prompt = (
+ "You are a stock analyst. Given recent news headlines, identify the 5-10 most "
+ "actionable US stock tickers. Return ONLY a JSON array with objects having: "
+ "symbol (ticker), direction ('BUY' or 'SELL'), score (0-1), reason (brief).\n\n"
+ "Headlines:\n" + "\n".join(headlines)
+ )
+
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ ANTHROPIC_API_URL,
+ headers={
+ "x-api-key": self._api_key,
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json",
+ },
+ json={
+ "model": self._model,
+ "max_tokens": 1024,
+ "messages": [{"role": "user", "content": prompt}],
+ },
+ ) as resp:
+ if resp.status != 200:
+ body = await resp.text()
+ logger.error("LLM candidate source error %d: %s", resp.status, body)
+ return []
+ data = await resp.json()
+
+ content = data.get("content", [])
+ text = ""
+ for block in content:
+ if isinstance(block, dict) and block.get("type") == "text":
+ text += block.get("text", "")
+
+ return self._parse_candidates(text)
+ except Exception as e:
+ logger.error("LLMCandidateSource error: %s", e)
+ return []
+
+ def _parse_candidates(self, text: str) -> list[Candidate]:
+ code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL)
+ if code_block:
+ raw = code_block.group(1)
+ else:
+ array_match = re.search(r"\[.*\]", text, re.DOTALL)
+ raw = array_match.group(0) if array_match else text.strip()
+
+ try:
+ items = json.loads(raw)
+ if not isinstance(items, list):
+ return []
+ candidates = []
+ for item in items:
+ if not isinstance(item, dict):
+ continue
+ try:
+ direction_str = item.get("direction", "BUY")
+ direction = OrderSide(direction_str)
+ except ValueError:
+ direction = None
+ candidates.append(
+ Candidate(
+ symbol=item["symbol"],
+ source="llm",
+ direction=direction,
+ score=float(item.get("score", 0.5)),
+ reason=item.get("reason", ""),
+ )
+ )
+ return candidates
+ except (json.JSONDecodeError, TypeError, KeyError):
+ return []
+
+
+def _compute_rsi(closes: list[float], period: int = 14) -> float:
+ """Compute RSI for the last data point."""
+ if len(closes) < period + 1:
+ return 50.0 # neutral if insufficient data
+
+ deltas = [closes[i] - closes[i - 1] for i in range(1, len(closes))]
+ gains = [d if d > 0 else 0.0 for d in deltas]
+ losses = [-d if d < 0 else 0.0 for d in deltas]
+
+ avg_gain = sum(gains[:period]) / period
+ avg_loss = sum(losses[:period]) / period
+
+ for i in range(period, len(deltas)):
+ avg_gain = (avg_gain * (period - 1) + gains[i]) / period
+ avg_loss = (avg_loss * (period - 1) + losses[i]) / period
+
+ if avg_loss == 0:
+ return 100.0
+ rs = avg_gain / avg_loss
+ return 100.0 - (100.0 / (1.0 + rs))
+
+
+class StockSelector:
+ """Orchestrates the 3-stage stock selection pipeline."""
+
+ def __init__(
+ self,
+ db: Database,
+ broker: RedisBroker,
+ alpaca: AlpacaClient,
+ anthropic_api_key: str,
+ anthropic_model: str = "claude-sonnet-4-20250514",
+ max_picks: int = 3,
+ ) -> None:
+ self._db = db
+ self._broker = broker
+ self._alpaca = alpaca
+ self._api_key = anthropic_api_key
+ self._model = anthropic_model
+ self._max_picks = max_picks
+
+ async def select(self) -> list[SelectedStock]:
+ """Run the full 3-stage pipeline and return selected stocks."""
+ # Market gate: check sentiment
+ sentiment_data = await self._db.get_latest_market_sentiment()
+ if sentiment_data is None:
+ logger.warning("No market sentiment data; skipping selection")
+ return []
+
+ market_sentiment = MarketSentiment(**sentiment_data)
+ if market_sentiment.market_regime == "risk_off":
+ logger.info("Market is risk_off; skipping stock selection")
+ return []
+
+ # Stage 1: gather candidates from both sources
+ sentiment_source = SentimentCandidateSource(self._db)
+ llm_source = LLMCandidateSource(self._db, self._api_key, self._model)
+
+ sentiment_candidates = await sentiment_source.get_candidates()
+ llm_candidates = await llm_source.get_candidates()
+
+ candidates = self._merge_candidates(sentiment_candidates, llm_candidates)
+ if not candidates:
+ logger.info("No candidates found")
+ return []
+
+ # Stage 2: technical filter
+ filtered = await self._technical_filter(candidates)
+ if not filtered:
+ logger.info("All candidates filtered out by technical criteria")
+ return []
+
+ # Stage 3: LLM final selection
+ selections = await self._llm_final_select(filtered, market_sentiment)
+
+ # Persist and publish
+ today = datetime.now(timezone.utc).date()
+ sentiment_snapshot = {
+ "fear_greed": market_sentiment.fear_greed,
+ "market_regime": market_sentiment.market_regime,
+ "vix": market_sentiment.vix,
+ }
+ for stock in selections:
+ try:
+ await self._db.insert_stock_selection(
+ trade_date=today,
+ symbol=stock.symbol,
+ side=stock.side.value,
+ conviction=stock.conviction,
+ reason=stock.reason,
+ key_news=stock.key_news,
+ sentiment_snapshot=sentiment_snapshot,
+ )
+ except Exception as e:
+ logger.error("Failed to persist selection for %s: %s", stock.symbol, e)
+
+ try:
+ await self._broker.publish(
+ "selected_stocks",
+ {
+ "symbol": stock.symbol,
+ "side": stock.side.value,
+ "conviction": stock.conviction,
+ "reason": stock.reason,
+ "key_news": stock.key_news,
+ "trade_date": str(today),
+ },
+ )
+ except Exception as e:
+ logger.error("Failed to publish selection for %s: %s", stock.symbol, e)
+
+ return selections
+
+ def _merge_candidates(
+ self, sentiment: list[Candidate], llm: list[Candidate]
+ ) -> list[Candidate]:
+ """Deduplicate candidates by symbol, keeping the higher score."""
+ by_symbol: dict[str, Candidate] = {}
+ for c in sentiment + llm:
+ existing = by_symbol.get(c.symbol)
+ if existing is None or c.score > existing.score:
+ by_symbol[c.symbol] = c
+ return sorted(by_symbol.values(), key=lambda c: c.score, reverse=True)
+
+ async def _technical_filter(self, candidates: list[Candidate]) -> list[Candidate]:
+ """Filter candidates using RSI, EMA20, and volume criteria."""
+ passed = []
+ for candidate in candidates:
+ try:
+ bars = await self._alpaca.get_bars(candidate.symbol, timeframe="1Day", limit=60)
+ if len(bars) < 21:
+ logger.debug("Insufficient bars for %s", candidate.symbol)
+ continue
+
+ closes = [float(b["c"]) for b in bars]
+ volumes = [float(b["v"]) for b in bars]
+
+ rsi = _compute_rsi(closes)
+ if not (30 <= rsi <= 70):
+ logger.debug("%s RSI=%.1f outside 30-70", candidate.symbol, rsi)
+ continue
+
+ ema20 = sum(closes[-20:]) / 20 # simple approximation
+ current_price = closes[-1]
+ if current_price <= ema20:
+ logger.debug("%s price %.2f <= EMA20 %.2f", candidate.symbol, current_price, ema20)
+ continue
+
+ avg_volume = sum(volumes[:-1]) / max(len(volumes) - 1, 1)
+ current_volume = volumes[-1]
+ if current_volume <= 0.5 * avg_volume:
+ logger.debug(
+ "%s volume %.0f <= 50%% avg %.0f",
+ candidate.symbol, current_volume, avg_volume,
+ )
+ continue
+
+ passed.append(candidate)
+ except Exception as e:
+ logger.warning("Technical filter error for %s: %s", candidate.symbol, e)
+
+ return passed
+
+ async def _llm_final_select(
+ self, candidates: list[Candidate], market_sentiment: MarketSentiment
+ ) -> list[SelectedStock]:
+ """Ask Claude to pick 2-3 stocks with rationale."""
+ candidate_lines = [
+ f"- {c.symbol} (source={c.source}, score={c.score:.2f}, reason={c.reason})"
+ for c in candidates
+ ]
+ market_context = (
+ f"Fear/Greed: {market_sentiment.fear_greed} ({market_sentiment.fear_greed_label}), "
+ f"VIX: {market_sentiment.vix}, "
+ f"Fed stance: {market_sentiment.fed_stance}, "
+ f"Regime: {market_sentiment.market_regime}"
+ )
+
+ prompt = (
+ f"You are a portfolio manager. Select 2-3 stocks for today's session.\n\n"
+ f"Market context: {market_context}\n\n"
+ f"Candidates (already passed technical filters):\n"
+ + "\n".join(candidate_lines)
+ + "\n\n"
+ "Return ONLY a JSON array with objects having:\n"
+ " symbol, side ('BUY' or 'SELL'), conviction (0-1), reason (1-2 sentences), "
+ "key_news (list of 1-3 relevant headlines or facts)\n"
+ f"Select at most {self._max_picks} stocks."
+ )
+
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ ANTHROPIC_API_URL,
+ headers={
+ "x-api-key": self._api_key,
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json",
+ },
+ json={
+ "model": self._model,
+ "max_tokens": 1024,
+ "messages": [{"role": "user", "content": prompt}],
+ },
+ ) as resp:
+ if resp.status != 200:
+ body = await resp.text()
+ logger.error("LLM final select error %d: %s", resp.status, body)
+ return []
+ data = await resp.json()
+
+ content = data.get("content", [])
+ text = ""
+ for block in content:
+ if isinstance(block, dict) and block.get("type") == "text":
+ text += block.get("text", "")
+
+ return _parse_llm_selections(text)[: self._max_picks]
+ except Exception as e:
+ logger.error("LLM final select error: %s", e)
+ return []
diff --git a/services/strategy-engine/tests/conftest.py b/services/strategy-engine/tests/conftest.py
index eb31b23..2b909ef 100644
--- a/services/strategy-engine/tests/conftest.py
+++ b/services/strategy-engine/tests/conftest.py
@@ -7,3 +7,8 @@ from pathlib import Path
STRATEGIES_DIR = Path(__file__).parent.parent / "strategies"
if str(STRATEGIES_DIR) not in sys.path:
sys.path.insert(0, str(STRATEGIES_DIR.parent))
+
+# Ensure the worktree's strategy_engine src is preferred over any installed version
+WORKTREE_SRC = Path(__file__).parent.parent / "src"
+if str(WORKTREE_SRC) not in sys.path:
+ sys.path.insert(0, str(WORKTREE_SRC))
diff --git a/services/strategy-engine/tests/test_stock_selector.py b/services/strategy-engine/tests/test_stock_selector.py
new file mode 100644
index 0000000..a2f5bca
--- /dev/null
+++ b/services/strategy-engine/tests/test_stock_selector.py
@@ -0,0 +1,78 @@
+"""Tests for stock selector engine."""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+from datetime import datetime, timezone
+from decimal import Decimal
+
+from shared.models import OrderSide
+from shared.sentiment_models import SymbolScore, MarketSentiment, SelectedStock, Candidate
+
+from strategy_engine.stock_selector import (
+ SentimentCandidateSource,
+ StockSelector,
+ _parse_llm_selections,
+)
+
+
+async def test_sentiment_candidate_source():
+ mock_db = MagicMock()
+ mock_db.get_top_symbol_scores = AsyncMock(return_value=[
+ {"symbol": "AAPL", "composite": 0.8, "news_count": 5},
+ {"symbol": "NVDA", "composite": 0.6, "news_count": 3},
+ ])
+
+ source = SentimentCandidateSource(mock_db)
+ candidates = await source.get_candidates()
+
+ assert len(candidates) == 2
+ assert candidates[0].symbol == "AAPL"
+ assert candidates[0].source == "sentiment"
+
+
+def test_parse_llm_selections_valid():
+ llm_response = """
+ [
+ {"symbol": "NVDA", "side": "BUY", "conviction": 0.85, "reason": "AI demand", "key_news": ["NVDA beats earnings"]},
+ {"symbol": "XOM", "side": "BUY", "conviction": 0.72, "reason": "Oil surge", "key_news": ["Oil prices up"]}
+ ]
+ """
+ selections = _parse_llm_selections(llm_response)
+ assert len(selections) == 2
+ assert selections[0].symbol == "NVDA"
+ assert selections[0].conviction == 0.85
+
+
+def test_parse_llm_selections_invalid():
+ selections = _parse_llm_selections("not json")
+ assert selections == []
+
+
+def test_parse_llm_selections_with_markdown():
+ llm_response = """
+ Here are my picks:
+ ```json
+ [
+ {"symbol": "TSLA", "side": "BUY", "conviction": 0.7, "reason": "Momentum", "key_news": ["Tesla rally"]}
+ ]
+ ```
+ """
+ selections = _parse_llm_selections(llm_response)
+ assert len(selections) == 1
+ assert selections[0].symbol == "TSLA"
+
+
+async def test_selector_blocks_on_risk_off():
+ mock_db = MagicMock()
+ mock_db.get_latest_market_sentiment = AsyncMock(return_value={
+ "fear_greed": 15,
+ "fear_greed_label": "Extreme Fear",
+ "vix": 35.0,
+ "fed_stance": "neutral",
+ "market_regime": "risk_off",
+ "updated_at": datetime.now(timezone.utc),
+ })
+
+ selector = StockSelector(db=mock_db, broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test")
+ result = await selector.select()
+ assert result == []