diff options
Diffstat (limited to 'services')
28 files changed, 1875 insertions, 0 deletions
diff --git a/services/news-collector/Dockerfile b/services/news-collector/Dockerfile new file mode 100644 index 0000000..a8e5902 --- /dev/null +++ b/services/news-collector/Dockerfile @@ -0,0 +1,9 @@ +FROM python:3.12-slim +WORKDIR /app +COPY shared/ shared/ +RUN pip install --no-cache-dir ./shared +COPY services/news-collector/ services/news-collector/ +RUN pip install --no-cache-dir ./services/news-collector +RUN python -c "import nltk; nltk.download('vader_lexicon', quiet=True)" +ENV PYTHONPATH=/app +CMD ["python", "-m", "news_collector.main"] diff --git a/services/news-collector/pyproject.toml b/services/news-collector/pyproject.toml new file mode 100644 index 0000000..14c856a --- /dev/null +++ b/services/news-collector/pyproject.toml @@ -0,0 +1,25 @@ +[project] +name = "news-collector" +version = "0.1.0" +description = "News and sentiment data collector service" +requires-python = ">=3.12" +dependencies = [ + "trading-shared", + "feedparser>=6.0", + "nltk>=3.8", + "aiohttp>=3.9", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", + "aioresponses>=0.7", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/news_collector"] diff --git a/services/news-collector/src/news_collector/__init__.py b/services/news-collector/src/news_collector/__init__.py new file mode 100644 index 0000000..5547af2 --- /dev/null +++ b/services/news-collector/src/news_collector/__init__.py @@ -0,0 +1 @@ +"""News collector service.""" diff --git a/services/news-collector/src/news_collector/collectors/__init__.py b/services/news-collector/src/news_collector/collectors/__init__.py new file mode 100644 index 0000000..5ef36a7 --- /dev/null +++ b/services/news-collector/src/news_collector/collectors/__init__.py @@ -0,0 +1 @@ +"""News collectors.""" diff --git a/services/news-collector/src/news_collector/collectors/base.py 
b/services/news-collector/src/news_collector/collectors/base.py new file mode 100644 index 0000000..bb43fd6 --- /dev/null +++ b/services/news-collector/src/news_collector/collectors/base.py @@ -0,0 +1,18 @@ +"""Base class for all news collectors.""" + +from abc import ABC, abstractmethod + +from shared.models import NewsItem + + +class BaseCollector(ABC): + name: str = "base" + poll_interval: int = 300 # seconds + + @abstractmethod + async def collect(self) -> list[NewsItem]: + """Collect news items from the source.""" + + @abstractmethod + async def is_available(self) -> bool: + """Check if this data source is accessible.""" diff --git a/services/news-collector/src/news_collector/collectors/fear_greed.py b/services/news-collector/src/news_collector/collectors/fear_greed.py new file mode 100644 index 0000000..f79f716 --- /dev/null +++ b/services/news-collector/src/news_collector/collectors/fear_greed.py @@ -0,0 +1,63 @@ +"""CNN Fear & Greed Index collector.""" + +import logging +from dataclasses import dataclass +from typing import Optional + +import aiohttp + +from news_collector.collectors.base import BaseCollector + +logger = logging.getLogger(__name__) + +FEAR_GREED_URL = "https://production.dataviz.cnn.io/index/fearandgreed/graphdata" + + +@dataclass +class FearGreedResult: + fear_greed: int + fear_greed_label: str + + +class FearGreedCollector(BaseCollector): + name = "fear_greed" + poll_interval = 3600 # 1 hour + + async def is_available(self) -> bool: + return True + + async def _fetch_index(self) -> Optional[dict]: + headers = {"User-Agent": "Mozilla/5.0"} + try: + async with aiohttp.ClientSession() as session: + async with session.get( + FEAR_GREED_URL, headers=headers, timeout=aiohttp.ClientTimeout(total=10) + ) as resp: + if resp.status != 200: + return None + return await resp.json() + except Exception: + return None + + def _classify(self, score: int) -> str: + if score <= 20: + return "Extreme Fear" + if score <= 40: + return "Fear" + if score <= 60: 
+ return "Neutral" + if score <= 80: + return "Greed" + return "Extreme Greed" + + async def collect(self) -> Optional[FearGreedResult]: + data = await self._fetch_index() + if data is None: + return None + try: + fg = data["fear_and_greed"] + score = int(fg["score"]) + label = fg.get("rating", self._classify(score)) + return FearGreedResult(fear_greed=score, fear_greed_label=label) + except (KeyError, ValueError, TypeError): + return None diff --git a/services/news-collector/src/news_collector/collectors/fed.py b/services/news-collector/src/news_collector/collectors/fed.py new file mode 100644 index 0000000..fce4842 --- /dev/null +++ b/services/news-collector/src/news_collector/collectors/fed.py @@ -0,0 +1,119 @@ +"""Federal Reserve RSS collector with hawkish/dovish/neutral stance detection.""" + +import asyncio +import logging +from calendar import timegm +from datetime import datetime, timezone + +import feedparser +from nltk.sentiment.vader import SentimentIntensityAnalyzer + +from shared.models import NewsCategory, NewsItem + +from .base import BaseCollector + +logger = logging.getLogger(__name__) + +_FED_RSS_URL = "https://www.federalreserve.gov/feeds/press_all.xml" + +_HAWKISH_KEYWORDS = [ + "rate hike", + "interest rate increase", + "tighten", + "tightening", + "inflation", + "hawkish", + "restrictive", + "raise rates", + "hike rates", +] +_DOVISH_KEYWORDS = [ + "rate cut", + "interest rate decrease", + "easing", + "ease", + "stimulus", + "dovish", + "accommodative", + "lower rates", + "cut rates", + "quantitative easing", +] + + +def _detect_stance(text: str) -> str: + lower = text.lower() + hawkish_hits = sum(1 for kw in _HAWKISH_KEYWORDS if kw in lower) + dovish_hits = sum(1 for kw in _DOVISH_KEYWORDS if kw in lower) + if hawkish_hits > dovish_hits: + return "hawkish" + if dovish_hits > hawkish_hits: + return "dovish" + return "neutral" + + +class FedCollector(BaseCollector): + name: str = "fed" + poll_interval: int = 3600 + + def __init__(self) -> None: 
+ self._vader = SentimentIntensityAnalyzer() + + async def is_available(self) -> bool: + return True + + async def _fetch_fed_rss(self) -> list[dict]: + loop = asyncio.get_event_loop() + try: + parsed = await loop.run_in_executor(None, feedparser.parse, _FED_RSS_URL) + return parsed.get("entries", []) + except Exception as exc: + logger.error("Fed RSS fetch failed: %s", exc) + return [] + + def _parse_published(self, entry: dict) -> datetime: + published_parsed = entry.get("published_parsed") + if published_parsed: + try: + ts = timegm(published_parsed) + return datetime.fromtimestamp(ts, tz=timezone.utc) + except Exception: + pass + return datetime.now(timezone.utc) + + async def collect(self) -> list[NewsItem]: + try: + entries = await self._fetch_fed_rss() + except Exception as exc: + logger.error("Fed collector error: %s", exc) + return [] + + items: list[NewsItem] = [] + + for entry in entries: + title = entry.get("title", "").strip() + if not title: + continue + + summary = entry.get("summary", "") or "" + combined = f"{title} {summary}" + + sentiment = self._vader.polarity_scores(combined)["compound"] + stance = _detect_stance(combined) + published_at = self._parse_published(entry) + + items.append( + NewsItem( + source=self.name, + headline=title, + summary=summary or None, + url=entry.get("link") or None, + published_at=published_at, + symbols=[], + sentiment=sentiment, + category=NewsCategory.FED, + raw_data={"stance": stance, **dict(entry)}, + ) + ) + + return items diff --git a/services/news-collector/src/news_collector/collectors/finnhub.py b/services/news-collector/src/news_collector/collectors/finnhub.py new file mode 100644 index 0000000..13e3602 --- /dev/null +++ b/services/news-collector/src/news_collector/collectors/finnhub.py @@ -0,0 +1,88 @@ +"""Finnhub news collector with VADER sentiment analysis.""" + +import logging +from datetime import datetime, timezone + +import aiohttp +from nltk.sentiment.vader import SentimentIntensityAnalyzer + +from 
shared.models import NewsCategory, NewsItem + +from .base import BaseCollector + +logger = logging.getLogger(__name__) + +_CATEGORY_KEYWORDS: dict[NewsCategory, list[str]] = { + NewsCategory.FED: ["fed", "fomc", "rate", "federal reserve"], + NewsCategory.POLICY: ["tariff", "trump", "regulation", "policy", "trade war"], + NewsCategory.EARNINGS: ["earnings", "revenue", "profit", "eps", "guidance", "quarter"], +} + + +def _categorize(text: str) -> NewsCategory: + lower = text.lower() + for category, keywords in _CATEGORY_KEYWORDS.items(): + if any(kw in lower for kw in keywords): + return category + return NewsCategory.MACRO + + +class FinnhubCollector(BaseCollector): + name: str = "finnhub" + poll_interval: int = 300 + + _BASE_URL = "https://finnhub.io/api/v1/news" + + def __init__(self, api_key: str) -> None: + self._api_key = api_key + self._vader = SentimentIntensityAnalyzer() + + async def is_available(self) -> bool: + return bool(self._api_key) + + async def _fetch_news(self) -> list[dict]: + url = f"{self._BASE_URL}?category=general&token={self._api_key}" + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + resp.raise_for_status() + return await resp.json() + + async def collect(self) -> list[NewsItem]: + try: + raw_items = await self._fetch_news() + except Exception as exc: + logger.error("Finnhub fetch failed: %s", exc) + return [] + + items: list[NewsItem] = [] + for article in raw_items: + headline = article.get("headline", "") + summary = article.get("summary", "") + combined = f"{headline} {summary}" + + sentiment_scores = self._vader.polarity_scores(combined) + sentiment = sentiment_scores["compound"] + + ts = article.get("datetime", 0) + published_at = datetime.fromtimestamp(ts, tz=timezone.utc) + + related = article.get("related", "") + symbols = [t.strip() for t in related.split(",") if t.strip()] if related else [] + + category = _categorize(combined) + + items.append( + NewsItem( + source=self.name, + 
headline=headline, + summary=summary or None, + url=article.get("url") or None, + published_at=published_at, + symbols=symbols, + sentiment=sentiment, + category=category, + raw_data=article, + ) + ) + + return items diff --git a/services/news-collector/src/news_collector/collectors/reddit.py b/services/news-collector/src/news_collector/collectors/reddit.py new file mode 100644 index 0000000..226a2f9 --- /dev/null +++ b/services/news-collector/src/news_collector/collectors/reddit.py @@ -0,0 +1,97 @@ +"""Reddit social sentiment collector using JSON API with VADER sentiment analysis.""" + +import logging +import re +from datetime import datetime, timezone + +import aiohttp +from nltk.sentiment.vader import SentimentIntensityAnalyzer + +from shared.models import NewsCategory, NewsItem + +from .base import BaseCollector + +logger = logging.getLogger(__name__) + +_SUBREDDITS = ["wallstreetbets", "stocks", "investing"] +_MIN_SCORE = 50 + +_TICKER_PATTERN = re.compile( + r"\b(AAPL|MSFT|GOOGL|GOOG|AMZN|TSLA|NVDA|META|BRK\.?[AB]|JPM|V|UNH|XOM|" + r"JNJ|WMT|MA|PG|HD|CVX|MRK|LLY|ABBV|PFE|BAC|KO|AVGO|COST|MCD|TMO|" + r"CSCO|ACN|ABT|DHR|TXN|NEE|NFLX|PM|UPS|RTX|HON|QCOM|AMGN|LOW|IBM|" + r"INTC|AMD|PYPL|GS|MS|BLK|SPGI|CAT|DE|GE|MMM|BA|F|GM|DIS|CMCSA)\b" +) + + +class RedditCollector(BaseCollector): + name: str = "reddit" + poll_interval: int = 900 + + def __init__(self) -> None: + self._vader = SentimentIntensityAnalyzer() + + async def is_available(self) -> bool: + return True + + async def _fetch_subreddit(self, subreddit: str) -> list[dict]: + url = f"https://www.reddit.com/r/{subreddit}/hot.json?limit=25" + headers = {"User-Agent": "TradingPlatform/1.0 (research@example.com)"} + try: + async with aiohttp.ClientSession() as session: + async with session.get( + url, headers=headers, timeout=aiohttp.ClientTimeout(total=10) + ) as resp: + if resp.status == 200: + data = await resp.json() + return data.get("data", {}).get("children", []) + except Exception as exc: + 
logger.error("Reddit fetch failed for r/%s: %s", subreddit, exc) + return [] + + async def collect(self) -> list[NewsItem]: + seen_titles: set[str] = set() + items: list[NewsItem] = [] + + for subreddit in _SUBREDDITS: + try: + posts = await self._fetch_subreddit(subreddit) + except Exception as exc: + logger.error("Reddit collector error for r/%s: %s", subreddit, exc) + continue + + for post in posts: + post_data = post.get("data", {}) + title = post_data.get("title", "").strip() + score = post_data.get("score", 0) + + if not title or score < _MIN_SCORE: + continue + if title in seen_titles: + continue + seen_titles.add(title) + + selftext = post_data.get("selftext", "") or "" + combined = f"{title} {selftext}" + + sentiment = self._vader.polarity_scores(combined)["compound"] + symbols = list(dict.fromkeys(_TICKER_PATTERN.findall(combined))) + + created_utc = post_data.get("created_utc", 0) + published_at = datetime.fromtimestamp(created_utc, tz=timezone.utc) + + items.append( + NewsItem( + source=self.name, + headline=title, + summary=selftext or None, + url=post_data.get("url") or None, + published_at=published_at, + symbols=symbols, + sentiment=sentiment, + category=NewsCategory.SOCIAL, + raw_data=post_data, + ) + ) + + return items diff --git a/services/news-collector/src/news_collector/collectors/rss.py b/services/news-collector/src/news_collector/collectors/rss.py new file mode 100644 index 0000000..ddf8503 --- /dev/null +++ b/services/news-collector/src/news_collector/collectors/rss.py @@ -0,0 +1,105 @@ +"""RSS news collector using feedparser with VADER sentiment analysis.""" + +import asyncio +import logging +import re +from datetime import datetime, timezone +from time import mktime + +import feedparser +from nltk.sentiment.vader import SentimentIntensityAnalyzer + +from shared.models import NewsCategory, NewsItem + +from .base import BaseCollector + +logger = logging.getLogger(__name__) + +_DEFAULT_FEEDS = [ + "https://finance.yahoo.com/news/rssindex", + 
"https://news.google.com/rss/search?q=stock+market+finance&hl=en-US&gl=US&ceid=US:en", + "https://feeds.marketwatch.com/marketwatch/topstories/", +] + +_TICKER_PATTERN = re.compile( + r"\b(AAPL|MSFT|GOOGL|GOOG|AMZN|TSLA|NVDA|META|BRK\.?[AB]|JPM|V|UNH|XOM|" + r"JNJ|WMT|MA|PG|HD|CVX|MRK|LLY|ABBV|PFE|BAC|KO|AVGO|COST|MCD|TMO|" + r"CSCO|ACN|ABT|DHR|TXN|NEE|NFLX|PM|UPS|RTX|HON|QCOM|AMGN|LOW|IBM|" + r"INTC|AMD|PYPL|GS|MS|BLK|SPGI|CAT|DE|GE|MMM|BA|F|GM|DIS|CMCSA)\b" +) + + +class RSSCollector(BaseCollector): + name: str = "rss" + poll_interval: int = 600 + + def __init__(self, feeds: list[str] | None = None) -> None: + self._feeds = feeds if feeds is not None else _DEFAULT_FEEDS + self._vader = SentimentIntensityAnalyzer() + + async def is_available(self) -> bool: + return True + + async def _fetch_feeds(self) -> list[dict]: + loop = asyncio.get_event_loop() + results = [] + for url in self._feeds: + try: + parsed = await loop.run_in_executor(None, feedparser.parse, url) + results.append(parsed) + except Exception as exc: + logger.error("RSS fetch failed for %s: %s", url, exc) + return results + + def _parse_published(self, entry: dict) -> datetime: + parsed_time = entry.get("published_parsed") + if parsed_time: + try: + ts = mktime(parsed_time) + return datetime.fromtimestamp(ts, tz=timezone.utc) + except Exception: + pass + return datetime.now(timezone.utc) + + async def collect(self) -> list[NewsItem]: + try: + feeds = await self._fetch_feeds() + except Exception as exc: + logger.error("RSS collector error: %s", exc) + return [] + + seen_titles: set[str] = set() + items: list[NewsItem] = [] + + for feed in feeds: + for entry in feed.get("entries", []): + title = entry.get("title", "").strip() + if not title or title in seen_titles: + continue + seen_titles.add(title) + + summary = entry.get("summary", "") or "" + combined = f"{title} {summary}" + + sentiment_scores = self._vader.polarity_scores(combined) + sentiment = sentiment_scores["compound"] + + symbols = 
list(dict.fromkeys(_TICKER_PATTERN.findall(combined))) + + published_at = self._parse_published(entry) + + items.append( + NewsItem( + source=self.name, + headline=title, + summary=summary or None, + url=entry.get("link") or None, + published_at=published_at, + symbols=symbols, + sentiment=sentiment, + category=NewsCategory.MACRO, + raw_data=dict(entry), + ) + ) + + return items diff --git a/services/news-collector/src/news_collector/collectors/sec_edgar.py b/services/news-collector/src/news_collector/collectors/sec_edgar.py new file mode 100644 index 0000000..ca1d070 --- /dev/null +++ b/services/news-collector/src/news_collector/collectors/sec_edgar.py @@ -0,0 +1,100 @@ +"""SEC EDGAR filing collector (free, no API key required).""" + +import logging +from datetime import datetime, timezone + +import aiohttp +from nltk.sentiment.vader import SentimentIntensityAnalyzer + +from shared.models import NewsCategory, NewsItem +from news_collector.collectors.base import BaseCollector + +logger = logging.getLogger(__name__) + +TRACKED_CIKS = { + "0000320193": "AAPL", + "0000789019": "MSFT", + "0001652044": "GOOGL", + "0001018724": "AMZN", + "0001318605": "TSLA", + "0001045810": "NVDA", + "0001326801": "META", + "0000019617": "JPM", + "0000078003": "PFE", + "0000021344": "KO", +} + +SEC_USER_AGENT = "TradingPlatform research@example.com" + + +class SecEdgarCollector(BaseCollector): + name = "sec_edgar" + poll_interval = 1800 # 30 minutes + + def __init__(self) -> None: + self._vader = SentimentIntensityAnalyzer() + + async def is_available(self) -> bool: + return True + + async def _fetch_recent_filings(self) -> list[dict]: + results = [] + headers = {"User-Agent": SEC_USER_AGENT} + async with aiohttp.ClientSession() as session: + for cik, ticker in TRACKED_CIKS.items(): + try: + url = f"https://data.sec.gov/submissions/CIK{cik}.json" + async with session.get( + url, headers=headers, timeout=aiohttp.ClientTimeout(total=10) + ) as resp: + if resp.status == 200: + data = await 
resp.json() + data["tickers"] = [{"ticker": ticker}] + results.append(data) + except Exception as exc: + logger.warning("sec_fetch_failed", cik=cik, error=str(exc)) + return results + + async def collect(self) -> list[NewsItem]: + filings_data = await self._fetch_recent_filings() + items = [] + today = datetime.now(timezone.utc).strftime("%Y-%m-%d") + + for company_data in filings_data: + tickers = [t["ticker"] for t in company_data.get("tickers", [])] + company_name = company_data.get("name", "Unknown") + recent = company_data.get("filings", {}).get("recent", {}) + + forms = recent.get("form", []) + dates = recent.get("filingDate", []) + descriptions = recent.get("primaryDocDescription", []) + accessions = recent.get("accessionNumber", []) + + for i, form in enumerate(forms): + if form != "8-K": + continue + filing_date = dates[i] if i < len(dates) else "" + if filing_date != today: + continue + + desc = descriptions[i] if i < len(descriptions) else "8-K Filing" + accession = accessions[i] if i < len(accessions) else "" + headline = f"{company_name} ({', '.join(tickers)}): {form} - {desc}" + + items.append( + NewsItem( + source=self.name, + headline=headline, + summary=desc, + url=f"https://www.sec.gov/cgi-bin/browse-edgar?action=getcompany&accession={accession}", + published_at=datetime.strptime(filing_date, "%Y-%m-%d").replace( + tzinfo=timezone.utc + ), + symbols=tickers, + sentiment=self._vader.polarity_scores(headline)["compound"], + category=NewsCategory.FILING, + raw_data={"form": form, "accession": accession}, + ) + ) + + return items diff --git a/services/news-collector/src/news_collector/collectors/truth_social.py b/services/news-collector/src/news_collector/collectors/truth_social.py new file mode 100644 index 0000000..33ebc86 --- /dev/null +++ b/services/news-collector/src/news_collector/collectors/truth_social.py @@ -0,0 +1,86 @@ +"""Truth Social collector using Mastodon-compatible API with VADER sentiment analysis.""" + +import logging +import re 
+from datetime import datetime, timezone + +import aiohttp +from nltk.sentiment.vader import SentimentIntensityAnalyzer + +from shared.models import NewsCategory, NewsItem + +from .base import BaseCollector + +logger = logging.getLogger(__name__) + +_TRUMP_ACCOUNT_ID = "107780257626128497" +_API_URL = f"https://truthsocial.com/api/v1/accounts/{_TRUMP_ACCOUNT_ID}/statuses" + +_HTML_TAG_PATTERN = re.compile(r"<[^>]+>") + + +def _strip_html(text: str) -> str: + return _HTML_TAG_PATTERN.sub("", text).strip() + + +class TruthSocialCollector(BaseCollector): + name: str = "truth_social" + poll_interval: int = 900 + + def __init__(self) -> None: + self._vader = SentimentIntensityAnalyzer() + + async def is_available(self) -> bool: + return True + + async def _fetch_posts(self) -> list[dict]: + headers = {"User-Agent": "TradingPlatform/1.0 (research@example.com)"} + try: + async with aiohttp.ClientSession() as session: + async with session.get( + _API_URL, headers=headers, timeout=aiohttp.ClientTimeout(total=10) + ) as resp: + if resp.status == 200: + return await resp.json() + except Exception as exc: + logger.error("Truth Social fetch failed: %s", exc) + return [] + + async def collect(self) -> list[NewsItem]: + try: + posts = await self._fetch_posts() + except Exception as exc: + logger.error("Truth Social collector error: %s", exc) + return [] + + items: list[NewsItem] = [] + + for post in posts: + raw_content = post.get("content", "") or "" + content = _strip_html(raw_content) + if not content: + continue + + sentiment = self._vader.polarity_scores(content)["compound"] + + created_at_str = post.get("created_at", "") + try: + published_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) + except Exception: + published_at = datetime.now(timezone.utc) + + items.append( + NewsItem( + source=self.name, + headline=content[:200], + summary=content if len(content) > 200 else None, + url=post.get("url") or None, + published_at=published_at, + symbols=[], + 
sentiment=sentiment, + category=NewsCategory.POLICY, + raw_data=post, + ) + ) + + return items diff --git a/services/news-collector/src/news_collector/config.py b/services/news-collector/src/news_collector/config.py new file mode 100644 index 0000000..70d98f1 --- /dev/null +++ b/services/news-collector/src/news_collector/config.py @@ -0,0 +1,10 @@ +"""News Collector configuration.""" + +from shared.config import Settings + + +class NewsCollectorConfig(Settings): + health_port: int = 8084 + finnhub_api_key: str = "" + news_poll_interval: int = 300 + sentiment_aggregate_interval: int = 900 diff --git a/services/news-collector/src/news_collector/main.py b/services/news-collector/src/news_collector/main.py new file mode 100644 index 0000000..3493f7c --- /dev/null +++ b/services/news-collector/src/news_collector/main.py @@ -0,0 +1,193 @@ +"""News Collector Service — fetches news from multiple sources and aggregates sentiment.""" + +import asyncio +from datetime import datetime, timezone + +from shared.broker import RedisBroker +from shared.db import Database +from shared.events import NewsEvent +from shared.healthcheck import HealthCheckServer +from shared.logging import setup_logging +from shared.metrics import ServiceMetrics +from shared.models import NewsItem +from shared.notifier import TelegramNotifier +from shared.sentiment_models import MarketSentiment +from shared.sentiment import SentimentAggregator + +from news_collector.config import NewsCollectorConfig +from news_collector.collectors.finnhub import FinnhubCollector +from news_collector.collectors.rss import RSSCollector +from news_collector.collectors.sec_edgar import SecEdgarCollector +from news_collector.collectors.truth_social import TruthSocialCollector +from news_collector.collectors.reddit import RedditCollector +from news_collector.collectors.fear_greed import FearGreedCollector +from news_collector.collectors.fed import FedCollector + +# Health check port: base + 4 +HEALTH_PORT_OFFSET = 4 + + +async 
def run_collector_once(collector, db: Database, broker: RedisBroker) -> int: + """Run a single collector, store results in DB, publish to Redis. + + Returns the number of items collected. + """ + items: list[NewsItem] = await collector.collect() + count = 0 + for item in items: + await db.insert_news_item(item) + event = NewsEvent(data=item) + stream = f"news.{item.category.value}" + await broker.publish(stream, event.to_dict()) + count += 1 + return count + + +async def run_collector_loop(collector, db: Database, broker: RedisBroker, log) -> None: + """Run a collector repeatedly on its configured poll_interval.""" + while True: + try: + count = await run_collector_once(collector, db, broker) + log.info( + "collector_ran", + collector=collector.name, + count=count, + ) + except Exception as exc: + log.warning( + "collector_error", + collector=collector.name, + error=str(exc), + ) + await asyncio.sleep(collector.poll_interval) + + +async def run_fear_greed_loop(collector: FearGreedCollector, db: Database, log) -> None: + """Fetch Fear & Greed index on its interval and update MarketSentiment in DB.""" + while True: + try: + result = await collector.collect() + if result is not None: + ms = MarketSentiment( + fear_greed=result.fear_greed, + fear_greed_label=result.fear_greed_label, + vix=None, + fed_stance="neutral", + market_regime=_determine_regime(result.fear_greed, None), + updated_at=datetime.now(timezone.utc), + ) + await db.upsert_market_sentiment(ms) + log.info( + "fear_greed_updated", + value=result.fear_greed, + label=result.fear_greed_label, + ) + except Exception as exc: + log.warning("fear_greed_error", error=str(exc)) + await asyncio.sleep(collector.poll_interval) + + +async def run_aggregator_loop(db: Database, interval: int, log) -> None: + """Run SentimentAggregator every interval seconds and persist scores.""" + aggregator = SentimentAggregator() + while True: + await asyncio.sleep(interval) + try: + now = datetime.now(timezone.utc) + news_items = 
await db.get_recent_news(hours=24) + scores = aggregator.aggregate(news_items, now) + for score in scores.values(): + await db.upsert_symbol_score(score) + log.info("aggregation_complete", symbols=len(scores)) + except Exception as exc: + log.warning("aggregator_error", error=str(exc)) + + +def _determine_regime(fear_greed: int, vix: float | None) -> str: + """Classify market regime from fear/greed index and optional VIX.""" + aggregator = SentimentAggregator() + return aggregator.determine_regime(fear_greed, vix) + + +async def run() -> None: + config = NewsCollectorConfig() + log = setup_logging("news-collector", config.log_level, config.log_format) + metrics = ServiceMetrics("news_collector") + + notifier = TelegramNotifier( + bot_token=config.telegram_bot_token, + chat_id=config.telegram_chat_id, + ) + + db = Database(config.database_url) + await db.connect() + + broker = RedisBroker(config.redis_url) + + health = HealthCheckServer( + "news-collector", + port=config.health_port, + auth_token=config.metrics_auth_token, + ) + await health.start() + metrics.service_up.labels(service="news-collector").set(1) + + # Build collectors + finnhub = FinnhubCollector(api_key=config.finnhub_api_key) + rss = RSSCollector() + sec = SecEdgarCollector() + truth = TruthSocialCollector() + reddit = RedditCollector() + fear_greed = FearGreedCollector() + fed = FedCollector() + + news_collectors = [finnhub, rss, sec, truth, reddit, fed] + + log.info( + "starting", + collectors=[c.name for c in news_collectors], + poll_interval=config.news_poll_interval, + aggregate_interval=config.sentiment_aggregate_interval, + ) + + try: + tasks = [] + for collector in news_collectors: + tasks.append( + asyncio.create_task( + run_collector_loop(collector, db, broker, log), + name=f"collector-{collector.name}", + ) + ) + tasks.append( + asyncio.create_task( + run_fear_greed_loop(fear_greed, db, log), + name="fear-greed-loop", + ) + ) + tasks.append( + asyncio.create_task( + run_aggregator_loop(db, 
config.sentiment_aggregate_interval, log), + name="aggregator-loop", + ) + ) + await asyncio.gather(*tasks) + except Exception as exc: + log.error("fatal_error", error=str(exc)) + await notifier.send_error(str(exc), "news-collector") + raise + finally: + metrics.service_up.labels(service="news-collector").set(0) + for task in tasks: + task.cancel() + await notifier.close() + await broker.close() + await db.close() + + +def main() -> None: + asyncio.run(run()) + + +if __name__ == "__main__": + main() diff --git a/services/news-collector/tests/__init__.py b/services/news-collector/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/news-collector/tests/__init__.py diff --git a/services/news-collector/tests/test_fear_greed.py b/services/news-collector/tests/test_fear_greed.py new file mode 100644 index 0000000..d483aa6 --- /dev/null +++ b/services/news-collector/tests/test_fear_greed.py @@ -0,0 +1,49 @@ +"""Tests for CNN Fear & Greed Index collector.""" + +import pytest +from unittest.mock import AsyncMock, patch + +from news_collector.collectors.fear_greed import FearGreedCollector + + +@pytest.fixture +def collector(): + return FearGreedCollector() + + +def test_collector_name(collector): + assert collector.name == "fear_greed" + assert collector.poll_interval == 3600 + + +async def test_is_available(collector): + assert await collector.is_available() is True + + +async def test_collect_parses_api_response(collector): + mock_data = { + "fear_and_greed": { + "score": 45.0, + "rating": "Fear", + "timestamp": "2026-04-02T12:00:00+00:00", + } + } + with patch.object(collector, "_fetch_index", new_callable=AsyncMock, return_value=mock_data): + result = await collector.collect() + assert result.fear_greed == 45 + assert result.fear_greed_label == "Fear" + + +async def test_collect_returns_none_on_failure(collector): + with patch.object(collector, "_fetch_index", new_callable=AsyncMock, return_value=None): + result = await 
collector.collect() + assert result is None + + +def test_classify_label(): + c = FearGreedCollector() + assert c._classify(10) == "Extreme Fear" + assert c._classify(30) == "Fear" + assert c._classify(50) == "Neutral" + assert c._classify(70) == "Greed" + assert c._classify(85) == "Extreme Greed" diff --git a/services/news-collector/tests/test_fed.py b/services/news-collector/tests/test_fed.py new file mode 100644 index 0000000..d1a736b --- /dev/null +++ b/services/news-collector/tests/test_fed.py @@ -0,0 +1,37 @@ +"""Tests for Federal Reserve collector.""" + +import pytest +from unittest.mock import AsyncMock, patch +from news_collector.collectors.fed import FedCollector + + +@pytest.fixture +def collector(): + return FedCollector() + + +def test_collector_name(collector): + assert collector.name == "fed" + assert collector.poll_interval == 3600 + + +async def test_is_available(collector): + assert await collector.is_available() is True + + +async def test_collect_parses_rss(collector): + mock_entries = [ + { + "title": "Federal Reserve issues FOMC statement", + "link": "https://www.federalreserve.gov/newsevents/pressreleases/monetary20260402a.htm", + "published_parsed": (2026, 4, 2, 14, 0, 0, 0, 0, 0), + "summary": "The Federal Open Market Committee decided to maintain the target range...", + }, + ] + with patch.object( + collector, "_fetch_fed_rss", new_callable=AsyncMock, return_value=mock_entries + ): + items = await collector.collect() + assert len(items) == 1 + assert items[0].source == "fed" + assert items[0].category.value == "fed" diff --git a/services/news-collector/tests/test_finnhub.py b/services/news-collector/tests/test_finnhub.py new file mode 100644 index 0000000..a4cf169 --- /dev/null +++ b/services/news-collector/tests/test_finnhub.py @@ -0,0 +1,67 @@ +"""Tests for Finnhub news collector.""" + +import pytest +from unittest.mock import AsyncMock, patch + +from news_collector.collectors.finnhub import FinnhubCollector + + +@pytest.fixture +def 
collector(): + return FinnhubCollector(api_key="test_key") + + +def test_collector_name(collector): + assert collector.name == "finnhub" + assert collector.poll_interval == 300 + + +async def test_is_available_with_key(collector): + assert await collector.is_available() is True + + +async def test_is_available_without_key(): + c = FinnhubCollector(api_key="") + assert await c.is_available() is False + + +async def test_collect_parses_response(collector): + mock_response = [ + { + "category": "top news", + "datetime": 1711929600, + "headline": "AAPL beats earnings", + "id": 12345, + "related": "AAPL", + "source": "MarketWatch", + "summary": "Apple reported better than expected...", + "url": "https://example.com/article", + }, + { + "category": "top news", + "datetime": 1711929000, + "headline": "Fed holds rates steady", + "id": 12346, + "related": "", + "source": "Reuters", + "summary": "The Federal Reserve...", + "url": "https://example.com/fed", + }, + ] + + with patch.object(collector, "_fetch_news", new_callable=AsyncMock, return_value=mock_response): + items = await collector.collect() + + assert len(items) == 2 + assert items[0].source == "finnhub" + assert items[0].headline == "AAPL beats earnings" + assert items[0].symbols == ["AAPL"] + assert items[0].url == "https://example.com/article" + assert isinstance(items[0].sentiment, float) + assert items[1].symbols == [] + + +async def test_collect_handles_empty_response(collector): + with patch.object(collector, "_fetch_news", new_callable=AsyncMock, return_value=[]): + items = await collector.collect() + assert items == [] diff --git a/services/news-collector/tests/test_main.py b/services/news-collector/tests/test_main.py new file mode 100644 index 0000000..66190dc --- /dev/null +++ b/services/news-collector/tests/test_main.py @@ -0,0 +1,39 @@ +"""Tests for news collector scheduler.""" + +from unittest.mock import AsyncMock, MagicMock +from datetime import datetime, timezone +from shared.models import 
NewsCategory, NewsItem +from news_collector.main import run_collector_once + + +async def test_run_collector_once_stores_and_publishes(): + mock_item = NewsItem( + source="test", + headline="Test news", + published_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + sentiment=0.5, + category=NewsCategory.MACRO, + ) + mock_collector = MagicMock() + mock_collector.name = "test" + mock_collector.collect = AsyncMock(return_value=[mock_item]) + mock_db = MagicMock() + mock_db.insert_news_item = AsyncMock() + mock_broker = MagicMock() + mock_broker.publish = AsyncMock() + + count = await run_collector_once(mock_collector, mock_db, mock_broker) + assert count == 1 + mock_db.insert_news_item.assert_called_once_with(mock_item) + mock_broker.publish.assert_called_once() + + +async def test_run_collector_once_handles_empty(): + mock_collector = MagicMock() + mock_collector.name = "test" + mock_collector.collect = AsyncMock(return_value=[]) + mock_db = MagicMock() + mock_broker = MagicMock() + + count = await run_collector_once(mock_collector, mock_db, mock_broker) + assert count == 0 diff --git a/services/news-collector/tests/test_reddit.py b/services/news-collector/tests/test_reddit.py new file mode 100644 index 0000000..440b173 --- /dev/null +++ b/services/news-collector/tests/test_reddit.py @@ -0,0 +1,63 @@ +"""Tests for Reddit collector.""" + +import pytest +from unittest.mock import AsyncMock, patch +from news_collector.collectors.reddit import RedditCollector + + +@pytest.fixture +def collector(): + return RedditCollector() + + +def test_collector_name(collector): + assert collector.name == "reddit" + assert collector.poll_interval == 900 + + +async def test_is_available(collector): + assert await collector.is_available() is True + + +async def test_collect_parses_posts(collector): + mock_posts = [ + { + "data": { + "title": "NVDA to the moon! 
AI demand is insane", + "selftext": "Just loaded up on NVDA calls", + "url": "https://reddit.com/r/wallstreetbets/123", + "created_utc": 1711929600, + "score": 500, + "num_comments": 200, + "subreddit": "wallstreetbets", + } + }, + ] + with patch.object( + collector, "_fetch_subreddit", new_callable=AsyncMock, return_value=mock_posts + ): + items = await collector.collect() + assert len(items) >= 1 + assert items[0].source == "reddit" + assert items[0].category.value == "social" + + +async def test_collect_filters_low_score(collector): + mock_posts = [ + { + "data": { + "title": "Random question", + "selftext": "", + "url": "https://reddit.com/456", + "created_utc": 1711929600, + "score": 3, + "num_comments": 1, + "subreddit": "stocks", + } + }, + ] + with patch.object( + collector, "_fetch_subreddit", new_callable=AsyncMock, return_value=mock_posts + ): + items = await collector.collect() + assert items == [] diff --git a/services/news-collector/tests/test_rss.py b/services/news-collector/tests/test_rss.py new file mode 100644 index 0000000..e03250a --- /dev/null +++ b/services/news-collector/tests/test_rss.py @@ -0,0 +1,47 @@ +"""Tests for RSS news collector.""" + +import pytest +from unittest.mock import AsyncMock, patch + +from news_collector.collectors.rss import RSSCollector + + +@pytest.fixture +def collector(): + return RSSCollector() + + +def test_collector_name(collector): + assert collector.name == "rss" + assert collector.poll_interval == 600 + + +async def test_is_available(collector): + assert await collector.is_available() is True + + +async def test_collect_parses_feed(collector): + mock_feed = { + "entries": [ + { + "title": "NVDA surges on AI demand", + "link": "https://example.com/nvda", + "published_parsed": (2026, 4, 2, 12, 0, 0, 0, 0, 0), + "summary": "Nvidia stock jumped 5%...", + }, + { + "title": "Markets rally on jobs data", + "link": "https://example.com/market", + "published_parsed": (2026, 4, 2, 11, 0, 0, 0, 0, 0), + "summary": "The S&P 
500 rose...", + }, + ], + } + + with patch.object(collector, "_fetch_feeds", new_callable=AsyncMock, return_value=[mock_feed]): + items = await collector.collect() + + assert len(items) == 2 + assert items[0].source == "rss" + assert items[0].headline == "NVDA surges on AI demand" + assert isinstance(items[0].sentiment, float) diff --git a/services/news-collector/tests/test_sec_edgar.py b/services/news-collector/tests/test_sec_edgar.py new file mode 100644 index 0000000..5d4f69f --- /dev/null +++ b/services/news-collector/tests/test_sec_edgar.py @@ -0,0 +1,58 @@ +"""Tests for SEC EDGAR filing collector.""" + +import pytest +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch, MagicMock + +from news_collector.collectors.sec_edgar import SecEdgarCollector + + +@pytest.fixture +def collector(): + return SecEdgarCollector() + + +def test_collector_name(collector): + assert collector.name == "sec_edgar" + assert collector.poll_interval == 1800 + + +async def test_is_available(collector): + assert await collector.is_available() is True + + +async def test_collect_parses_filings(collector): + mock_response = { + "filings": { + "recent": { + "accessionNumber": ["0001234-26-000001"], + "filingDate": ["2026-04-02"], + "primaryDocument": ["filing.htm"], + "form": ["8-K"], + "primaryDocDescription": ["Current Report"], + } + }, + "tickers": [{"ticker": "AAPL"}], + "name": "Apple Inc", + } + + mock_datetime = MagicMock(spec=datetime) + mock_datetime.now.return_value = datetime(2026, 4, 2, tzinfo=timezone.utc) + mock_datetime.strptime = datetime.strptime + + with patch.object( + collector, "_fetch_recent_filings", new_callable=AsyncMock, return_value=[mock_response] + ): + with patch("news_collector.collectors.sec_edgar.datetime", mock_datetime): + items = await collector.collect() + + assert len(items) == 1 + assert items[0].source == "sec_edgar" + assert items[0].category.value == "filing" + assert "AAPL" in items[0].symbols + + +async def 
test_collect_handles_empty(collector): + with patch.object(collector, "_fetch_recent_filings", new_callable=AsyncMock, return_value=[]): + items = await collector.collect() + assert items == [] diff --git a/services/news-collector/tests/test_truth_social.py b/services/news-collector/tests/test_truth_social.py new file mode 100644 index 0000000..91ddb9d --- /dev/null +++ b/services/news-collector/tests/test_truth_social.py @@ -0,0 +1,41 @@ +"""Tests for Truth Social collector.""" + +import pytest +from unittest.mock import AsyncMock, patch +from news_collector.collectors.truth_social import TruthSocialCollector + + +@pytest.fixture +def collector(): + return TruthSocialCollector() + + +def test_collector_name(collector): + assert collector.name == "truth_social" + assert collector.poll_interval == 900 + + +async def test_is_available(collector): + assert await collector.is_available() is True + + +async def test_collect_parses_posts(collector): + mock_posts = [ + { + "content": "<p>We are imposing 25% tariffs on all steel imports!</p>", + "created_at": "2026-04-02T12:00:00.000Z", + "url": "https://truthsocial.com/@realDonaldTrump/12345", + "id": "12345", + }, + ] + with patch.object(collector, "_fetch_posts", new_callable=AsyncMock, return_value=mock_posts): + items = await collector.collect() + assert len(items) == 1 + assert items[0].source == "truth_social" + assert items[0].category.value == "policy" + + +async def test_collect_handles_empty(collector): + with patch.object(collector, "_fetch_posts", new_callable=AsyncMock, return_value=[]): + items = await collector.collect() + assert items == [] diff --git a/services/strategy-engine/src/strategy_engine/config.py b/services/strategy-engine/src/strategy_engine/config.py index 9fd9c49..15f8588 100644 --- a/services/strategy-engine/src/strategy_engine/config.py +++ b/services/strategy-engine/src/strategy_engine/config.py @@ -7,3 +7,7 @@ class StrategyConfig(Settings): symbols: list[str] = ["AAPL", "MSFT", "GOOGL", 
"AMZN", "TSLA"] timeframes: list[str] = ["1m"] strategy_params: dict = {} + selector_final_time: str = "15:30" + selector_max_picks: int = 3 + anthropic_api_key: str = "" + anthropic_model: str = "claude-sonnet-4-20250514" diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py index 30de528..5a30766 100644 --- a/services/strategy-engine/src/strategy_engine/main.py +++ b/services/strategy-engine/src/strategy_engine/main.py @@ -1,17 +1,23 @@ """Strategy Engine Service entry point.""" import asyncio +from datetime import datetime from pathlib import Path +import zoneinfo +from shared.alpaca import AlpacaClient from shared.broker import RedisBroker +from shared.db import Database from shared.healthcheck import HealthCheckServer from shared.logging import setup_logging from shared.metrics import ServiceMetrics from shared.notifier import TelegramNotifier +from shared.sentiment_models import MarketSentiment from strategy_engine.config import StrategyConfig from strategy_engine.engine import StrategyEngine from strategy_engine.plugin_loader import load_strategies +from strategy_engine.stock_selector import StockSelector # The strategies directory lives alongside the installed package STRATEGIES_DIR = Path(__file__).parent.parent.parent.parent / "strategies" @@ -30,6 +36,40 @@ async def process_symbol(engine: StrategyEngine, stream: str, log) -> None: last_id = await engine.process_once(stream, last_id) +async def run_stock_selector( + selector: StockSelector, + notifier: TelegramNotifier, + db: Database, + config: StrategyConfig, + log, +) -> None: + """Run the stock selector once per day at the configured time.""" + et = zoneinfo.ZoneInfo("America/New_York") + + while True: + now_et = datetime.now(et) + target_hour, target_min = map(int, config.selector_final_time.split(":")) + + if now_et.hour == target_hour and now_et.minute == target_min: + log.info("stock_selector_running") + try: + selections = await 
selector.select() + if selections: + ms_data = await db.get_latest_market_sentiment() + ms = None + if ms_data: + ms = MarketSentiment(**ms_data) + await notifier.send_stock_selection(selections, ms) + log.info("stock_selector_complete", picks=[s.symbol for s in selections]) + else: + log.info("stock_selector_no_picks") + except Exception as exc: + log.error("stock_selector_error", error=str(exc)) + await asyncio.sleep(120) # Sleep past this minute + else: + await asyncio.sleep(30) + + async def run() -> None: config = StrategyConfig() log = setup_logging("strategy-engine", config.log_level, config.log_format) @@ -41,6 +81,16 @@ async def run() -> None: ) broker = RedisBroker(config.redis_url) + + db = Database(config.database_url) + await db.connect() + + alpaca = AlpacaClient( + api_key=config.alpaca_api_key, + api_secret=config.alpaca_api_secret, + paper=config.alpaca_paper, + ) + strategies = load_strategies(STRATEGIES_DIR) for strategy in strategies: @@ -67,6 +117,20 @@ async def run() -> None: task = asyncio.create_task(process_symbol(engine, stream, log)) tasks.append(task) + if config.anthropic_api_key: + selector = StockSelector( + db=db, + broker=broker, + alpaca=alpaca, + anthropic_api_key=config.anthropic_api_key, + anthropic_model=config.anthropic_model, + max_picks=config.selector_max_picks, + ) + tasks.append( + asyncio.create_task(run_stock_selector(selector, notifier, db, config, log)) + ) + log.info("stock_selector_enabled", time=config.selector_final_time) + await asyncio.gather(*tasks) except Exception as exc: log.error("fatal_error", error=str(exc)) @@ -78,6 +142,8 @@ async def run() -> None: metrics.service_up.labels(service="strategy-engine").set(0) await notifier.close() await broker.close() + await alpaca.close() + await db.close() def main() -> None: diff --git a/services/strategy-engine/src/strategy_engine/stock_selector.py b/services/strategy-engine/src/strategy_engine/stock_selector.py new file mode 100644 index 0000000..268d557 --- 
def _parse_llm_selections(text: str) -> list[SelectedStock]:
    """Parse an LLM response into a list of SelectedStock.

    Handles both bare JSON arrays and markdown code blocks. Invalid
    individual items are skipped with a warning; a completely
    unparseable payload yields an empty list.
    """
    # Prefer JSON inside a fenced ```json ... ``` block, since models often
    # wrap their output in markdown.
    code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL)
    if code_block:
        raw = code_block.group(1)
    else:
        # Fall back to the first bare JSON array anywhere in the text.
        array_match = re.search(r"\[.*\]", text, re.DOTALL)
        if array_match:
            raw = array_match.group(0)
        else:
            raw = text.strip()

    try:
        data = json.loads(raw)
        if not isinstance(data, list):
            return []
        selections = []
        for item in data:
            if not isinstance(item, dict):
                continue
            try:
                selection = SelectedStock(
                    symbol=item["symbol"],
                    side=OrderSide(item["side"]),
                    conviction=float(item["conviction"]),
                    reason=item.get("reason", ""),
                    key_news=item.get("key_news", []),
                )
                selections.append(selection)
            # TypeError is included so a non-numeric conviction (e.g. JSON
            # null -> float(None)) skips only that item instead of escaping
            # to the outer handler and discarding every valid selection
            # parsed so far.
            except (KeyError, TypeError, ValueError) as e:
                logger.warning("Skipping invalid selection item: %s", e)
        return selections
    except (json.JSONDecodeError, TypeError):
        return []
self._db.get_top_symbol_scores(limit=20) + candidates = [] + for row in rows: + composite = float(row.get("composite", 0)) + if composite == 0: + continue + candidates.append( + Candidate( + symbol=row["symbol"], + source="sentiment", + score=composite, + reason=f"composite={composite:.2f}, news_count={row.get('news_count', 0)}", + ) + ) + return candidates + + +class LLMCandidateSource: + """Generates candidates by asking Claude to analyze recent news.""" + + def __init__(self, db: Database, api_key: str, model: str) -> None: + self._db = db + self._api_key = api_key + self._model = model + + async def get_candidates(self) -> list[Candidate]: + news_items = await self._db.get_recent_news(hours=24) + if not news_items: + return [] + + headlines = [] + for item in news_items[:50]: # cap at 50 to stay within context + symbols = item.get("symbols", []) + sym_str = ", ".join(symbols) if symbols else "N/A" + headlines.append(f"[{sym_str}] {item['headline']}") + + prompt = ( + "You are a stock analyst. Given recent news headlines, identify the 5-10 most " + "actionable US stock tickers. 
Return ONLY a JSON array with objects having: " + "symbol (ticker), direction ('BUY' or 'SELL'), score (0-1), reason (brief).\n\n" + "Headlines:\n" + "\n".join(headlines) + ) + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM candidate source error %d: %s", resp.status, body) + return [] + data = await resp.json() + + content = data.get("content", []) + text = "" + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text += block.get("text", "") + + return self._parse_candidates(text) + except Exception as e: + logger.error("LLMCandidateSource error: %s", e) + return [] + + def _parse_candidates(self, text: str) -> list[Candidate]: + code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) + if code_block: + raw = code_block.group(1) + else: + array_match = re.search(r"\[.*\]", text, re.DOTALL) + raw = array_match.group(0) if array_match else text.strip() + + try: + items = json.loads(raw) + if not isinstance(items, list): + return [] + candidates = [] + for item in items: + if not isinstance(item, dict): + continue + try: + direction_str = item.get("direction", "BUY") + direction = OrderSide(direction_str) + except ValueError: + direction = None + candidates.append( + Candidate( + symbol=item["symbol"], + source="llm", + direction=direction, + score=float(item.get("score", 0.5)), + reason=item.get("reason", ""), + ) + ) + return candidates + except (json.JSONDecodeError, TypeError, KeyError): + return [] + + +def _compute_rsi(closes: list[float], period: int = 14) -> float: + """Compute RSI for the last data point.""" + if len(closes) < period + 
1: + return 50.0 # neutral if insufficient data + + deltas = [closes[i] - closes[i - 1] for i in range(1, len(closes))] + gains = [d if d > 0 else 0.0 for d in deltas] + losses = [-d if d < 0 else 0.0 for d in deltas] + + avg_gain = sum(gains[:period]) / period + avg_loss = sum(losses[:period]) / period + + for i in range(period, len(deltas)): + avg_gain = (avg_gain * (period - 1) + gains[i]) / period + avg_loss = (avg_loss * (period - 1) + losses[i]) / period + + if avg_loss == 0: + return 100.0 + rs = avg_gain / avg_loss + return 100.0 - (100.0 / (1.0 + rs)) + + +class StockSelector: + """Orchestrates the 3-stage stock selection pipeline.""" + + def __init__( + self, + db: Database, + broker: RedisBroker, + alpaca: AlpacaClient, + anthropic_api_key: str, + anthropic_model: str = "claude-sonnet-4-20250514", + max_picks: int = 3, + ) -> None: + self._db = db + self._broker = broker + self._alpaca = alpaca + self._api_key = anthropic_api_key + self._model = anthropic_model + self._max_picks = max_picks + + async def select(self) -> list[SelectedStock]: + """Run the full 3-stage pipeline and return selected stocks.""" + # Market gate: check sentiment + sentiment_data = await self._db.get_latest_market_sentiment() + if sentiment_data is None: + logger.warning("No market sentiment data; skipping selection") + return [] + + market_sentiment = MarketSentiment(**sentiment_data) + if market_sentiment.market_regime == "risk_off": + logger.info("Market is risk_off; skipping stock selection") + return [] + + # Stage 1: gather candidates from both sources + sentiment_source = SentimentCandidateSource(self._db) + llm_source = LLMCandidateSource(self._db, self._api_key, self._model) + + sentiment_candidates = await sentiment_source.get_candidates() + llm_candidates = await llm_source.get_candidates() + + candidates = self._merge_candidates(sentiment_candidates, llm_candidates) + if not candidates: + logger.info("No candidates found") + return [] + + # Stage 2: technical filter + 
filtered = await self._technical_filter(candidates) + if not filtered: + logger.info("All candidates filtered out by technical criteria") + return [] + + # Stage 3: LLM final selection + selections = await self._llm_final_select(filtered, market_sentiment) + + # Persist and publish + today = datetime.now(timezone.utc).date() + sentiment_snapshot = { + "fear_greed": market_sentiment.fear_greed, + "market_regime": market_sentiment.market_regime, + "vix": market_sentiment.vix, + } + for stock in selections: + try: + await self._db.insert_stock_selection( + trade_date=today, + symbol=stock.symbol, + side=stock.side.value, + conviction=stock.conviction, + reason=stock.reason, + key_news=stock.key_news, + sentiment_snapshot=sentiment_snapshot, + ) + except Exception as e: + logger.error("Failed to persist selection for %s: %s", stock.symbol, e) + + try: + await self._broker.publish( + "selected_stocks", + { + "symbol": stock.symbol, + "side": stock.side.value, + "conviction": stock.conviction, + "reason": stock.reason, + "key_news": stock.key_news, + "trade_date": str(today), + }, + ) + except Exception as e: + logger.error("Failed to publish selection for %s: %s", stock.symbol, e) + + return selections + + def _merge_candidates( + self, sentiment: list[Candidate], llm: list[Candidate] + ) -> list[Candidate]: + """Deduplicate candidates by symbol, keeping the higher score.""" + by_symbol: dict[str, Candidate] = {} + for c in sentiment + llm: + existing = by_symbol.get(c.symbol) + if existing is None or c.score > existing.score: + by_symbol[c.symbol] = c + return sorted(by_symbol.values(), key=lambda c: c.score, reverse=True) + + async def _technical_filter(self, candidates: list[Candidate]) -> list[Candidate]: + """Filter candidates using RSI, EMA20, and volume criteria.""" + passed = [] + for candidate in candidates: + try: + bars = await self._alpaca.get_bars(candidate.symbol, timeframe="1Day", limit=60) + if len(bars) < 21: + logger.debug("Insufficient bars for %s", 
candidate.symbol) + continue + + closes = [float(b["c"]) for b in bars] + volumes = [float(b["v"]) for b in bars] + + rsi = _compute_rsi(closes) + if not (30 <= rsi <= 70): + logger.debug("%s RSI=%.1f outside 30-70", candidate.symbol, rsi) + continue + + ema20 = sum(closes[-20:]) / 20 # simple approximation + current_price = closes[-1] + if current_price <= ema20: + logger.debug( + "%s price %.2f <= EMA20 %.2f", candidate.symbol, current_price, ema20 + ) + continue + + avg_volume = sum(volumes[:-1]) / max(len(volumes) - 1, 1) + current_volume = volumes[-1] + if current_volume <= 0.5 * avg_volume: + logger.debug( + "%s volume %.0f <= 50%% avg %.0f", + candidate.symbol, + current_volume, + avg_volume, + ) + continue + + passed.append(candidate) + except Exception as e: + logger.warning("Technical filter error for %s: %s", candidate.symbol, e) + + return passed + + async def _llm_final_select( + self, candidates: list[Candidate], market_sentiment: MarketSentiment + ) -> list[SelectedStock]: + """Ask Claude to pick 2-3 stocks with rationale.""" + candidate_lines = [ + f"- {c.symbol} (source={c.source}, score={c.score:.2f}, reason={c.reason})" + for c in candidates + ] + market_context = ( + f"Fear/Greed: {market_sentiment.fear_greed} ({market_sentiment.fear_greed_label}), " + f"VIX: {market_sentiment.vix}, " + f"Fed stance: {market_sentiment.fed_stance}, " + f"Regime: {market_sentiment.market_regime}" + ) + + prompt = ( + f"You are a portfolio manager. Select 2-3 stocks for today's session.\n\n" + f"Market context: {market_context}\n\n" + f"Candidates (already passed technical filters):\n" + + "\n".join(candidate_lines) + + "\n\n" + "Return ONLY a JSON array with objects having:\n" + " symbol, side ('BUY' or 'SELL'), conviction (0-1), reason (1-2 sentences), " + "key_news (list of 1-3 relevant headlines or facts)\n" + f"Select at most {self._max_picks} stocks." 
+ ) + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM final select error %d: %s", resp.status, body) + return [] + data = await resp.json() + + content = data.get("content", []) + text = "" + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text += block.get("text", "") + + return _parse_llm_selections(text)[: self._max_picks] + except Exception as e: + logger.error("LLM final select error: %s", e) + return [] diff --git a/services/strategy-engine/tests/conftest.py b/services/strategy-engine/tests/conftest.py index eb31b23..2b909ef 100644 --- a/services/strategy-engine/tests/conftest.py +++ b/services/strategy-engine/tests/conftest.py @@ -7,3 +7,8 @@ from pathlib import Path STRATEGIES_DIR = Path(__file__).parent.parent / "strategies" if str(STRATEGIES_DIR) not in sys.path: sys.path.insert(0, str(STRATEGIES_DIR.parent)) + +# Ensure the worktree's strategy_engine src is preferred over any installed version +WORKTREE_SRC = Path(__file__).parent.parent / "src" +if str(WORKTREE_SRC) not in sys.path: + sys.path.insert(0, str(WORKTREE_SRC)) diff --git a/services/strategy-engine/tests/test_stock_selector.py b/services/strategy-engine/tests/test_stock_selector.py new file mode 100644 index 0000000..ff9d09c --- /dev/null +++ b/services/strategy-engine/tests/test_stock_selector.py @@ -0,0 +1,80 @@ +"""Tests for stock selector engine.""" + +from unittest.mock import AsyncMock, MagicMock +from datetime import datetime, timezone + + +from strategy_engine.stock_selector import ( + SentimentCandidateSource, + StockSelector, + _parse_llm_selections, +) + + +async def 
test_sentiment_candidate_source(): + mock_db = MagicMock() + mock_db.get_top_symbol_scores = AsyncMock( + return_value=[ + {"symbol": "AAPL", "composite": 0.8, "news_count": 5}, + {"symbol": "NVDA", "composite": 0.6, "news_count": 3}, + ] + ) + + source = SentimentCandidateSource(mock_db) + candidates = await source.get_candidates() + + assert len(candidates) == 2 + assert candidates[0].symbol == "AAPL" + assert candidates[0].source == "sentiment" + + +def test_parse_llm_selections_valid(): + llm_response = """ + [ + {"symbol": "NVDA", "side": "BUY", "conviction": 0.85, "reason": "AI demand", "key_news": ["NVDA beats earnings"]}, + {"symbol": "XOM", "side": "BUY", "conviction": 0.72, "reason": "Oil surge", "key_news": ["Oil prices up"]} + ] + """ + selections = _parse_llm_selections(llm_response) + assert len(selections) == 2 + assert selections[0].symbol == "NVDA" + assert selections[0].conviction == 0.85 + + +def test_parse_llm_selections_invalid(): + selections = _parse_llm_selections("not json") + assert selections == [] + + +def test_parse_llm_selections_with_markdown(): + llm_response = """ + Here are my picks: + ```json + [ + {"symbol": "TSLA", "side": "BUY", "conviction": 0.7, "reason": "Momentum", "key_news": ["Tesla rally"]} + ] + ``` + """ + selections = _parse_llm_selections(llm_response) + assert len(selections) == 1 + assert selections[0].symbol == "TSLA" + + +async def test_selector_blocks_on_risk_off(): + mock_db = MagicMock() + mock_db.get_latest_market_sentiment = AsyncMock( + return_value={ + "fear_greed": 15, + "fear_greed_label": "Extreme Fear", + "vix": 35.0, + "fed_stance": "neutral", + "market_regime": "risk_off", + "updated_at": datetime.now(timezone.utc), + } + ) + + selector = StockSelector( + db=mock_db, broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test" + ) + result = await selector.select() + assert result == [] |
