From ff5a74810ef5f8748b2b30e543968dd9e48512f0 Mon Sep 17 00:00:00 2001 From: TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> Date: Thu, 2 Apr 2026 13:54:57 +0900 Subject: feat: add NewsEvent, DB methods for news/sentiment/selections - Add NEWS to EventType enum and NewsEvent class to events.py - Add insert_news_item, get_recent_news, upsert_symbol_score, get_top_symbol_scores, upsert_market_sentiment, get_latest_market_sentiment, insert_stock_selection, get_stock_selections methods to Database class in db.py - Add test_news_events.py and test_db_news.py with full coverage --- shared/src/shared/db.py | 247 ++++++++++++++++++++++++++++++++++++++- shared/src/shared/events.py | 19 ++- shared/tests/test_db_news.py | 80 +++++++++++++ shared/tests/test_news_events.py | 56 +++++++++ 4 files changed, 398 insertions(+), 4 deletions(-) create mode 100644 shared/tests/test_db_news.py create mode 100644 shared/tests/test_news_events.py diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py index 901e293..55f93b4 100644 --- a/shared/src/shared/db.py +++ b/shared/src/shared/db.py @@ -1,15 +1,28 @@ """Database layer using SQLAlchemy 2.0 async ORM for the trading platform.""" +import json +import uuid from contextlib import asynccontextmanager -from datetime import datetime, timedelta, timezone +from datetime import datetime, date, timedelta, timezone from decimal import Decimal from typing import Optional from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker -from shared.models import Candle, Signal, Order, OrderStatus -from shared.sa_models import Base, CandleRow, SignalRow, OrderRow, PortfolioSnapshotRow +from shared.models import Candle, Signal, Order, OrderStatus, NewsItem +from shared.sentiment_models import SymbolScore, MarketSentiment +from shared.sa_models import ( + Base, + CandleRow, + SignalRow, + OrderRow, + PortfolioSnapshotRow, + NewsItemRow, + SymbolScoreRow, + MarketSentimentRow, + StockSelectionRow, +) class Database: @@ -195,3 +208,231 @@ class Database: } for r in rows ] + + async def insert_news_item(self, item: NewsItem) -> None: + """Insert a NewsItem row, JSON-encoding symbols and raw_data.""" + row = NewsItemRow( + id=item.id, + source=item.source, + headline=item.headline, + summary=item.summary, + url=item.url, + published_at=item.published_at, + symbols=json.dumps(item.symbols), + sentiment=item.sentiment, + category=item.category.value, + raw_data=json.dumps(item.raw_data), + created_at=item.created_at, + ) + async with self._session_factory() as session: + try: + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise + + async def get_recent_news(self, hours: int = 24) -> list[dict]: + """Retrieve news items published in the last N hours.""" + since = datetime.now(timezone.utc) - timedelta(hours=hours) + stmt = ( + select(NewsItemRow) + .where(NewsItemRow.published_at >= since) + .order_by(NewsItemRow.published_at.desc()) + ) + async with self._session_factory() as session: + try: + result = await session.execute(stmt) + rows = result.scalars().all() + except Exception: + await session.rollback() + raise + return [ + { + "id": r.id, + "source": r.source, + "headline": r.headline, + "summary": r.summary, + "url": r.url, + "published_at": r.published_at, + "symbols": json.loads(r.symbols) if r.symbols else [], + "sentiment": r.sentiment, + "category": r.category, + "raw_data": json.loads(r.raw_data) if r.raw_data else {}, + "created_at": r.created_at, + } + for r in rows + ] + + async def upsert_symbol_score(self, score: SymbolScore) -> None: + """Insert or update a SymbolScore row, keyed by symbol.""" + async with self._session_factory() as session: + try: + stmt = select(SymbolScoreRow).where(SymbolScoreRow.symbol == score.symbol) + result = await session.execute(stmt) + existing = result.scalar_one_or_none() + if existing is not None: + existing.news_score = score.news_score + existing.news_count = score.news_count + existing.social_score = score.social_score + existing.policy_score = score.policy_score + existing.filing_score = score.filing_score + existing.composite = score.composite + existing.updated_at = score.updated_at + else: + row = SymbolScoreRow( + id=str(uuid.uuid4()), + symbol=score.symbol, + news_score=score.news_score, + news_count=score.news_count, + social_score=score.social_score, + policy_score=score.policy_score, + filing_score=score.filing_score, + composite=score.composite, + updated_at=score.updated_at, + ) + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise + + async def get_top_symbol_scores(self, limit: int = 20) -> list[dict]: + """Retrieve top symbol scores ordered by composite descending.""" + stmt = ( + select(SymbolScoreRow) + .order_by(SymbolScoreRow.composite.desc()) + .limit(limit) + ) + async with self._session_factory() as session: + try: + result = await session.execute(stmt) + rows = result.scalars().all() + except Exception: + await session.rollback() + raise + return [ + { + "id": r.id, + "symbol": r.symbol, + "news_score": r.news_score, + "news_count": r.news_count, + "social_score": r.social_score, + "policy_score": r.policy_score, + "filing_score": r.filing_score, + "composite": r.composite, + "updated_at": r.updated_at, + } + for r in rows + ] + + async def upsert_market_sentiment(self, ms: MarketSentiment) -> None: + """Insert or update the single 'latest' market sentiment row.""" + async with self._session_factory() as session: + try: + stmt = select(MarketSentimentRow).where(MarketSentimentRow.id == "latest") + result = await session.execute(stmt) + existing = result.scalar_one_or_none() + if existing is not None: + existing.fear_greed = ms.fear_greed + existing.fear_greed_label = ms.fear_greed_label + existing.vix = ms.vix + existing.fed_stance = ms.fed_stance + existing.market_regime = ms.market_regime + existing.updated_at = ms.updated_at + else: + row = MarketSentimentRow( + id="latest", + fear_greed=ms.fear_greed, + fear_greed_label=ms.fear_greed_label, + vix=ms.vix, + fed_stance=ms.fed_stance, + market_regime=ms.market_regime, + updated_at=ms.updated_at, + ) + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise + + async def get_latest_market_sentiment(self) -> Optional[dict]: + """Retrieve the 'latest' market sentiment row, or None if not found.""" + stmt = select(MarketSentimentRow).where(MarketSentimentRow.id == "latest") + async with self._session_factory() as session: + try: + result = await session.execute(stmt) + row = result.scalar_one_or_none() + except Exception: + await session.rollback() + raise + if row is None: + return None + return { + "id": row.id, + "fear_greed": row.fear_greed, + "fear_greed_label": row.fear_greed_label, + "vix": row.vix, + "fed_stance": row.fed_stance, + "market_regime": row.market_regime, + "updated_at": row.updated_at, + } + + async def insert_stock_selection( + self, + trade_date: date, + symbol: str, + side: str, + conviction: float, + reason: str, + key_news: list, + sentiment_snapshot: dict, + ) -> None: + """Insert a stock selection row with JSON-encoded lists/dicts.""" + row = StockSelectionRow( + id=str(uuid.uuid4()), + trade_date=trade_date, + symbol=symbol, + side=side, + conviction=conviction, + reason=reason, + key_news=json.dumps(key_news), + sentiment_snapshot=json.dumps(sentiment_snapshot), + created_at=datetime.now(timezone.utc), + ) + async with self._session_factory() as session: + try: + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise + + async def get_stock_selections(self, trade_date: date) -> list[dict]: + """Retrieve stock selections for a given trade date.""" + stmt = ( + select(StockSelectionRow) + .where(StockSelectionRow.trade_date == trade_date) + .order_by(StockSelectionRow.conviction.desc()) + ) + async with self._session_factory() as session: + try: + result = await session.execute(stmt) + rows = result.scalars().all() + except Exception: + await session.rollback() + raise + return [ + { + "id": r.id, + "trade_date": r.trade_date, + "symbol": r.symbol, + "side": r.side, + "conviction": r.conviction, + "reason": r.reason, + "key_news": json.loads(r.key_news) if r.key_news else [], + "sentiment_snapshot": json.loads(r.sentiment_snapshot) if r.sentiment_snapshot else {}, + "created_at": r.created_at, + } + for r in rows + ] diff --git a/shared/src/shared/events.py b/shared/src/shared/events.py index 72f8865..63f93a2 100644 --- a/shared/src/shared/events.py +++ b/shared/src/shared/events.py @@ -5,13 +5,14 @@ from typing import Any from pydantic import BaseModel -from shared.models import Candle, Signal, Order +from shared.models import Candle, Signal, Order, NewsItem class EventType(str, Enum): CANDLE = "CANDLE" SIGNAL = "SIGNAL" ORDER = "ORDER" + NEWS = "NEWS" class CandleEvent(BaseModel): @@ -59,10 +60,26 @@ class OrderEvent(BaseModel): return cls(type=raw["type"], data=Order(**raw["data"])) +class NewsEvent(BaseModel): + type: EventType = EventType.NEWS + data: NewsItem + + def to_dict(self) -> dict: + return { + "type": self.type, + "data": self.data.model_dump(mode="json"), + } + + @classmethod + def from_raw(cls, raw: dict) -> "NewsEvent": + return cls(type=raw["type"], data=NewsItem(**raw["data"])) + + _EVENT_TYPE_MAP = { EventType.CANDLE: CandleEvent, EventType.SIGNAL: SignalEvent, EventType.ORDER: OrderEvent, + EventType.NEWS: NewsEvent, } diff --git a/shared/tests/test_db_news.py b/shared/tests/test_db_news.py new file mode 100644 index 0000000..f13cf1e --- /dev/null +++ b/shared/tests/test_db_news.py @@ -0,0 +1,80 @@ +"""Tests for database news/sentiment methods. Uses in-memory SQLite.""" + +import json +import uuid +import pytest +from datetime import datetime, date, timezone + +from shared.db import Database +from shared.models import NewsItem, NewsCategory +from shared.sentiment_models import SymbolScore, MarketSentiment + + +@pytest.fixture +async def db(): + database = Database("sqlite+aiosqlite://") + await database.connect() + yield database + await database.close() + + +async def test_insert_and_get_news_items(db): + item = NewsItem( + source="finnhub", + headline="AAPL earnings beat", + published_at=datetime(2026, 4, 2, 12, 0, tzinfo=timezone.utc), + sentiment=0.8, + category=NewsCategory.EARNINGS, + symbols=["AAPL"], + ) + await db.insert_news_item(item) + items = await db.get_recent_news(hours=24) + assert len(items) == 1 + assert items[0]["headline"] == "AAPL earnings beat" + + +async def test_upsert_symbol_score(db): + score = SymbolScore( + symbol="AAPL", + news_score=0.5, + news_count=10, + social_score=0.3, + policy_score=0.0, + filing_score=0.2, + composite=0.3, + updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + ) + await db.upsert_symbol_score(score) + scores = await db.get_top_symbol_scores(limit=5) + assert len(scores) == 1 + assert scores[0]["symbol"] == "AAPL" + + +async def test_upsert_market_sentiment(db): + ms = MarketSentiment( + fear_greed=55, + fear_greed_label="Neutral", + vix=18.2, + fed_stance="neutral", + market_regime="neutral", + updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + ) + await db.upsert_market_sentiment(ms) + result = await db.get_latest_market_sentiment() + assert result is not None + assert result["fear_greed"] == 55 + + +async def test_insert_stock_selection(db): + await db.insert_stock_selection( + trade_date=date(2026, 4, 2), + symbol="NVDA", + side="BUY", + conviction=0.85, + reason="CHIPS Act", + key_news=["Trump signs CHIPS expansion"], + sentiment_snapshot={"composite": 0.8}, + ) + selections = await db.get_stock_selections(date(2026, 4, 2)) + assert len(selections) == 1 + assert selections[0]["symbol"] == "NVDA" diff --git a/shared/tests/test_news_events.py b/shared/tests/test_news_events.py new file mode 100644 index 0000000..384796a --- /dev/null +++ b/shared/tests/test_news_events.py @@ -0,0 +1,56 @@ +"""Tests for NewsEvent.""" + +from datetime import datetime, timezone + +from shared.models import NewsCategory, NewsItem +from shared.events import NewsEvent, EventType, Event + + +def test_news_event_to_dict(): + item = NewsItem( + source="finnhub", + headline="Test", + published_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + sentiment=0.5, + category=NewsCategory.MACRO, + ) + event = NewsEvent(data=item) + d = event.to_dict() + assert d["type"] == EventType.NEWS + assert d["data"]["source"] == "finnhub" + + +def test_news_event_from_raw(): + raw = { + "type": "NEWS", + "data": { + "id": "abc", + "source": "rss", + "headline": "Test headline", + "published_at": "2026-04-02T00:00:00+00:00", + "sentiment": 0.3, + "category": "earnings", + "symbols": ["AAPL"], + "raw_data": {}, + }, + } + event = NewsEvent.from_raw(raw) + assert event.data.source == "rss" + assert event.data.symbols == ["AAPL"] + + +def test_event_dispatcher_news(): + raw = { + "type": "NEWS", + "data": { + "id": "abc", + "source": "finnhub", + "headline": "Test", + "published_at": "2026-04-02T00:00:00+00:00", + "sentiment": 0.0, + "category": "macro", + "raw_data": {}, + }, + } + event = Event.from_dict(raw) + assert isinstance(event, NewsEvent) -- cgit v1.2.3