diff options
| author | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-02 14:07:45 +0900 |
|---|---|---|
| committer | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-02 14:07:45 +0900 |
| commit | 9bf5aef24d83065c093d8cf4e32d789efd07777b (patch) | |
| tree | 138bdfeec1b2f89e53458da4b2862a48b4523173 /shared | |
| parent | 5e2d5887d1f6bc7919948e3f269cfa00e243cb9f (diff) | |
feat: implement SentimentAggregator with freshness decay and composite scoring
Diffstat (limited to 'shared')
| -rw-r--r-- | shared/src/shared/sentiment.py | 106 | ||||
| -rw-r--r-- | shared/tests/test_sentiment_aggregator.py | 69 |
2 files changed, 174 insertions, 1 deletions
diff --git a/shared/src/shared/sentiment.py b/shared/src/shared/sentiment.py index 8213b47..a20227e 100644 --- a/shared/src/shared/sentiment.py +++ b/shared/src/shared/sentiment.py @@ -2,7 +2,9 @@ import logging from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone + +from shared.sentiment_models import SymbolScore logger = logging.getLogger(__name__) @@ -33,3 +35,105 @@ class SentimentData: if self.news_sentiment is not None and self.news_sentiment < -0.5: return True return False + + +def _safe_avg(values: list[float]) -> float: + if not values: + return 0.0 + return sum(values) / len(values) + + +class SentimentAggregator: + """Aggregates per-news sentiment into per-symbol scores.""" + + WEIGHTS = {"news": 0.3, "social": 0.2, "policy": 0.3, "filing": 0.2} + + CATEGORY_MAP = { + "earnings": "news", + "macro": "news", + "social": "social", + "policy": "policy", + "filing": "filing", + "fed": "policy", + } + + def _freshness_decay(self, published_at: datetime, now: datetime) -> float: + age = now - published_at + hours = age.total_seconds() / 3600 + if hours < 1: + return 1.0 + if hours < 6: + return 0.7 + if hours < 24: + return 0.3 + return 0.0 + + def _compute_composite( + self, + news_score: float, + social_score: float, + policy_score: float, + filing_score: float, + ) -> float: + return ( + news_score * self.WEIGHTS["news"] + + social_score * self.WEIGHTS["social"] + + policy_score * self.WEIGHTS["policy"] + + filing_score * self.WEIGHTS["filing"] + ) + + def aggregate( + self, news_items: list[dict], now: datetime + ) -> dict[str, SymbolScore]: + """Aggregate news items into per-symbol scores. + + Each dict needs: symbols, sentiment, category, published_at. + """ + symbol_data: dict[str, dict] = {} + + for item in news_items: + decay = self._freshness_decay(item["published_at"], now) + if decay == 0.0: + continue + category = item.get("category", "macro") + score_field = self.CATEGORY_MAP.get(category, "news") + weighted_sentiment = item["sentiment"] * decay + + for symbol in item.get("symbols", []): + if symbol not in symbol_data: + symbol_data[symbol] = { + "news_scores": [], + "social_scores": [], + "policy_scores": [], + "filing_scores": [], + "count": 0, + } + symbol_data[symbol][f"{score_field}_scores"].append(weighted_sentiment) + symbol_data[symbol]["count"] += 1 + + result = {} + for symbol, data in symbol_data.items(): + ns = _safe_avg(data["news_scores"]) + ss = _safe_avg(data["social_scores"]) + ps = _safe_avg(data["policy_scores"]) + fs = _safe_avg(data["filing_scores"]) + result[symbol] = SymbolScore( + symbol=symbol, + news_score=ns, + news_count=data["count"], + social_score=ss, + policy_score=ps, + filing_score=fs, + composite=self._compute_composite(ns, ss, ps, fs), + updated_at=now, + ) + return result + + def determine_regime(self, fear_greed: int, vix: float | None) -> str: + if fear_greed <= 20: + return "risk_off" + if vix is not None and vix > 30: + return "risk_off" + if fear_greed >= 60 and (vix is None or vix < 20): + return "risk_on" + return "neutral" diff --git a/shared/tests/test_sentiment_aggregator.py b/shared/tests/test_sentiment_aggregator.py new file mode 100644 index 0000000..f9277e7 --- /dev/null +++ b/shared/tests/test_sentiment_aggregator.py @@ -0,0 +1,69 @@ +"""Tests for sentiment aggregator.""" +import pytest +from datetime import datetime, timezone, timedelta +from shared.sentiment import SentimentAggregator + + +@pytest.fixture +def aggregator(): + return SentimentAggregator() + + +def test_freshness_decay_recent(): + a = SentimentAggregator() + now = datetime.now(timezone.utc) + assert a._freshness_decay(now, now) == 1.0 + + +def test_freshness_decay_3_hours(): + a = SentimentAggregator() + now = datetime.now(timezone.utc) + assert a._freshness_decay(now - timedelta(hours=3), now) == 0.7 + + +def test_freshness_decay_12_hours(): + a = SentimentAggregator() + now = datetime.now(timezone.utc) + assert a._freshness_decay(now - timedelta(hours=12), now) == 0.3 + + +def test_freshness_decay_old(): + a = SentimentAggregator() + now = datetime.now(timezone.utc) + assert a._freshness_decay(now - timedelta(days=2), now) == 0.0 + + +def test_compute_composite(): + a = SentimentAggregator() + composite = a._compute_composite(news_score=0.5, social_score=0.3, policy_score=0.8, filing_score=0.2) + expected = 0.5 * 0.3 + 0.3 * 0.2 + 0.8 * 0.3 + 0.2 * 0.2 + assert abs(composite - expected) < 0.001 + + +def test_aggregate_news_by_symbol(aggregator): + now = datetime.now(timezone.utc) + news_items = [ + {"symbols": ["AAPL"], "sentiment": 0.8, "category": "earnings", "published_at": now}, + {"symbols": ["AAPL"], "sentiment": 0.3, "category": "macro", "published_at": now - timedelta(hours=2)}, + {"symbols": ["MSFT"], "sentiment": -0.5, "category": "policy", "published_at": now}, + ] + scores = aggregator.aggregate(news_items, now) + assert "AAPL" in scores + assert "MSFT" in scores + assert scores["AAPL"].news_count == 2 + assert scores["AAPL"].news_score > 0 + assert scores["MSFT"].policy_score < 0 + + +def test_aggregate_empty(aggregator): + now = datetime.now(timezone.utc) + assert aggregator.aggregate([], now) == {} + + +def test_determine_regime(): + a = SentimentAggregator() + assert a.determine_regime(15, None) == "risk_off" + assert a.determine_regime(15, 35.0) == "risk_off" + assert a.determine_regime(50, 35.0) == "risk_off" + assert a.determine_regime(70, 15.0) == "risk_on" + assert a.determine_regime(50, 20.0) == "neutral" |
