summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--shared/src/shared/sentiment.py106
-rw-r--r--shared/tests/test_sentiment_aggregator.py69
2 files changed, 174 insertions, 1 deletions
diff --git a/shared/src/shared/sentiment.py b/shared/src/shared/sentiment.py
index 8213b47..a20227e 100644
--- a/shared/src/shared/sentiment.py
+++ b/shared/src/shared/sentiment.py
@@ -2,7 +2,9 @@
import logging
from dataclasses import dataclass, field
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
+
+from shared.sentiment_models import SymbolScore
logger = logging.getLogger(__name__)
@@ -33,3 +35,105 @@ class SentimentData:
if self.news_sentiment is not None and self.news_sentiment < -0.5:
return True
return False
+
+
+def _safe_avg(values: list[float]) -> float:
+ if not values:
+ return 0.0
+ return sum(values) / len(values)
+
+
+class SentimentAggregator:
+ """Aggregates per-news sentiment into per-symbol scores."""
+
+ WEIGHTS = {"news": 0.3, "social": 0.2, "policy": 0.3, "filing": 0.2}
+
+ CATEGORY_MAP = {
+ "earnings": "news",
+ "macro": "news",
+ "social": "social",
+ "policy": "policy",
+ "filing": "filing",
+ "fed": "policy",
+ }
+
+ def _freshness_decay(self, published_at: datetime, now: datetime) -> float:
+ age = now - published_at
+ hours = age.total_seconds() / 3600
+ if hours < 1:
+ return 1.0
+ if hours < 6:
+ return 0.7
+ if hours < 24:
+ return 0.3
+ return 0.0
+
+ def _compute_composite(
+ self,
+ news_score: float,
+ social_score: float,
+ policy_score: float,
+ filing_score: float,
+ ) -> float:
+ return (
+ news_score * self.WEIGHTS["news"]
+ + social_score * self.WEIGHTS["social"]
+ + policy_score * self.WEIGHTS["policy"]
+ + filing_score * self.WEIGHTS["filing"]
+ )
+
+ def aggregate(
+ self, news_items: list[dict], now: datetime
+ ) -> dict[str, SymbolScore]:
+ """Aggregate news items into per-symbol scores.
+
+ Each dict needs: symbols, sentiment, category, published_at.
+ """
+ symbol_data: dict[str, dict] = {}
+
+ for item in news_items:
+ decay = self._freshness_decay(item["published_at"], now)
+ if decay == 0.0:
+ continue
+ category = item.get("category", "macro")
+ score_field = self.CATEGORY_MAP.get(category, "news")
+ weighted_sentiment = item["sentiment"] * decay
+
+ for symbol in item.get("symbols", []):
+ if symbol not in symbol_data:
+ symbol_data[symbol] = {
+ "news_scores": [],
+ "social_scores": [],
+ "policy_scores": [],
+ "filing_scores": [],
+ "count": 0,
+ }
+ symbol_data[symbol][f"{score_field}_scores"].append(weighted_sentiment)
+ symbol_data[symbol]["count"] += 1
+
+ result = {}
+ for symbol, data in symbol_data.items():
+ ns = _safe_avg(data["news_scores"])
+ ss = _safe_avg(data["social_scores"])
+ ps = _safe_avg(data["policy_scores"])
+ fs = _safe_avg(data["filing_scores"])
+ result[symbol] = SymbolScore(
+ symbol=symbol,
+ news_score=ns,
+ news_count=data["count"],
+ social_score=ss,
+ policy_score=ps,
+ filing_score=fs,
+ composite=self._compute_composite(ns, ss, ps, fs),
+ updated_at=now,
+ )
+ return result
+
+ def determine_regime(self, fear_greed: int, vix: float | None) -> str:
+ if fear_greed <= 20:
+ return "risk_off"
+ if vix is not None and vix > 30:
+ return "risk_off"
+ if fear_greed >= 60 and (vix is None or vix < 20):
+ return "risk_on"
+ return "neutral"
diff --git a/shared/tests/test_sentiment_aggregator.py b/shared/tests/test_sentiment_aggregator.py
new file mode 100644
index 0000000..f9277e7
--- /dev/null
+++ b/shared/tests/test_sentiment_aggregator.py
@@ -0,0 +1,69 @@
+"""Tests for sentiment aggregator."""
+import pytest
+from datetime import datetime, timezone, timedelta
+from shared.sentiment import SentimentAggregator
+
+
+@pytest.fixture
+def aggregator():
+ return SentimentAggregator()
+
+
+def test_freshness_decay_recent():
+ a = SentimentAggregator()
+ now = datetime.now(timezone.utc)
+ assert a._freshness_decay(now, now) == 1.0
+
+
+def test_freshness_decay_3_hours():
+ a = SentimentAggregator()
+ now = datetime.now(timezone.utc)
+ assert a._freshness_decay(now - timedelta(hours=3), now) == 0.7
+
+
+def test_freshness_decay_12_hours():
+ a = SentimentAggregator()
+ now = datetime.now(timezone.utc)
+ assert a._freshness_decay(now - timedelta(hours=12), now) == 0.3
+
+
+def test_freshness_decay_old():
+ a = SentimentAggregator()
+ now = datetime.now(timezone.utc)
+ assert a._freshness_decay(now - timedelta(days=2), now) == 0.0
+
+
+def test_compute_composite():
+ a = SentimentAggregator()
+ composite = a._compute_composite(news_score=0.5, social_score=0.3, policy_score=0.8, filing_score=0.2)
+ expected = 0.5 * 0.3 + 0.3 * 0.2 + 0.8 * 0.3 + 0.2 * 0.2
+ assert abs(composite - expected) < 0.001
+
+
+def test_aggregate_news_by_symbol(aggregator):
+ now = datetime.now(timezone.utc)
+ news_items = [
+ {"symbols": ["AAPL"], "sentiment": 0.8, "category": "earnings", "published_at": now},
+ {"symbols": ["AAPL"], "sentiment": 0.3, "category": "macro", "published_at": now - timedelta(hours=2)},
+ {"symbols": ["MSFT"], "sentiment": -0.5, "category": "policy", "published_at": now},
+ ]
+ scores = aggregator.aggregate(news_items, now)
+ assert "AAPL" in scores
+ assert "MSFT" in scores
+ assert scores["AAPL"].news_count == 2
+ assert scores["AAPL"].news_score > 0
+ assert scores["MSFT"].policy_score < 0
+
+
+def test_aggregate_empty(aggregator):
+ now = datetime.now(timezone.utc)
+ assert aggregator.aggregate([], now) == {}
+
+
+def test_determine_regime():
+ a = SentimentAggregator()
+ assert a.determine_regime(15, None) == "risk_off"
+ assert a.determine_regime(15, 35.0) == "risk_off"
+ assert a.determine_regime(50, 35.0) == "risk_off"
+ assert a.determine_regime(70, 15.0) == "risk_on"
+ assert a.determine_regime(50, 20.0) == "neutral"