diff options
| author | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-02 13:54:57 +0900 |
|---|---|---|
| committer | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-02 13:54:57 +0900 |
| commit | ff5a74810ef5f8748b2b30e543968dd9e48512f0 (patch) | |
| tree | 064abadbdf227fa95d9bd37231499cb72cda7e06 /shared/src | |
| parent | b781c8f7c0371a28faf61099b38c28cfed3c46b3 (diff) | |
feat: add NewsEvent, DB methods for news/sentiment/selections
- Add NEWS to EventType enum and NewsEvent class to events.py
- Add insert_news_item, get_recent_news, upsert_symbol_score,
get_top_symbol_scores, upsert_market_sentiment,
get_latest_market_sentiment, insert_stock_selection,
get_stock_selections methods to Database class in db.py
- Add test_news_events.py and test_db_news.py with full coverage
Diffstat (limited to 'shared/src')
| -rw-r--r-- | shared/src/shared/db.py | 247 | ||||
| -rw-r--r-- | shared/src/shared/events.py | 19 |
2 files changed, 262 insertions, 4 deletions
diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py index 901e293..55f93b4 100644 --- a/shared/src/shared/db.py +++ b/shared/src/shared/db.py @@ -1,15 +1,28 @@ """Database layer using SQLAlchemy 2.0 async ORM for the trading platform.""" +import json +import uuid from contextlib import asynccontextmanager -from datetime import datetime, timedelta, timezone +from datetime import datetime, date, timedelta, timezone from decimal import Decimal from typing import Optional from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker -from shared.models import Candle, Signal, Order, OrderStatus -from shared.sa_models import Base, CandleRow, SignalRow, OrderRow, PortfolioSnapshotRow +from shared.models import Candle, Signal, Order, OrderStatus, NewsItem +from shared.sentiment_models import SymbolScore, MarketSentiment +from shared.sa_models import ( + Base, + CandleRow, + SignalRow, + OrderRow, + PortfolioSnapshotRow, + NewsItemRow, + SymbolScoreRow, + MarketSentimentRow, + StockSelectionRow, +) class Database: @@ -195,3 +208,231 @@ class Database: } for r in rows ] + + async def insert_news_item(self, item: NewsItem) -> None: + """Insert a NewsItem row, JSON-encoding symbols and raw_data.""" + row = NewsItemRow( + id=item.id, + source=item.source, + headline=item.headline, + summary=item.summary, + url=item.url, + published_at=item.published_at, + symbols=json.dumps(item.symbols), + sentiment=item.sentiment, + category=item.category.value, + raw_data=json.dumps(item.raw_data), + created_at=item.created_at, + ) + async with self._session_factory() as session: + try: + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise + + async def get_recent_news(self, hours: int = 24) -> list[dict]: + """Retrieve news items published in the last N hours.""" + since = datetime.now(timezone.utc) - timedelta(hours=hours) + stmt = ( + select(NewsItemRow) + .where(NewsItemRow.published_at >= since) + .order_by(NewsItemRow.published_at.desc()) + ) + async with self._session_factory() as session: + try: + result = await session.execute(stmt) + rows = result.scalars().all() + except Exception: + await session.rollback() + raise + return [ + { + "id": r.id, + "source": r.source, + "headline": r.headline, + "summary": r.summary, + "url": r.url, + "published_at": r.published_at, + "symbols": json.loads(r.symbols) if r.symbols else [], + "sentiment": r.sentiment, + "category": r.category, + "raw_data": json.loads(r.raw_data) if r.raw_data else {}, + "created_at": r.created_at, + } + for r in rows + ] + + async def upsert_symbol_score(self, score: SymbolScore) -> None: + """Insert or update a SymbolScore row, keyed by symbol.""" + async with self._session_factory() as session: + try: + stmt = select(SymbolScoreRow).where(SymbolScoreRow.symbol == score.symbol) + result = await session.execute(stmt) + existing = result.scalar_one_or_none() + if existing is not None: + existing.news_score = score.news_score + existing.news_count = score.news_count + existing.social_score = score.social_score + existing.policy_score = score.policy_score + existing.filing_score = score.filing_score + existing.composite = score.composite + existing.updated_at = score.updated_at + else: + row = SymbolScoreRow( + id=str(uuid.uuid4()), + symbol=score.symbol, + news_score=score.news_score, + news_count=score.news_count, + social_score=score.social_score, + policy_score=score.policy_score, + filing_score=score.filing_score, + composite=score.composite, + updated_at=score.updated_at, + ) + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise + + async def get_top_symbol_scores(self, limit: int = 20) -> list[dict]: + """Retrieve top symbol scores ordered by composite descending.""" + stmt = ( + select(SymbolScoreRow) + .order_by(SymbolScoreRow.composite.desc()) + .limit(limit) + ) + async with self._session_factory() as session: + try: + result = await session.execute(stmt) + rows = result.scalars().all() + except Exception: + await session.rollback() + raise + return [ + { + "id": r.id, + "symbol": r.symbol, + "news_score": r.news_score, + "news_count": r.news_count, + "social_score": r.social_score, + "policy_score": r.policy_score, + "filing_score": r.filing_score, + "composite": r.composite, + "updated_at": r.updated_at, + } + for r in rows + ] + + async def upsert_market_sentiment(self, ms: MarketSentiment) -> None: + """Insert or update the single 'latest' market sentiment row.""" + async with self._session_factory() as session: + try: + stmt = select(MarketSentimentRow).where(MarketSentimentRow.id == "latest") + result = await session.execute(stmt) + existing = result.scalar_one_or_none() + if existing is not None: + existing.fear_greed = ms.fear_greed + existing.fear_greed_label = ms.fear_greed_label + existing.vix = ms.vix + existing.fed_stance = ms.fed_stance + existing.market_regime = ms.market_regime + existing.updated_at = ms.updated_at + else: + row = MarketSentimentRow( + id="latest", + fear_greed=ms.fear_greed, + fear_greed_label=ms.fear_greed_label, + vix=ms.vix, + fed_stance=ms.fed_stance, + market_regime=ms.market_regime, + updated_at=ms.updated_at, + ) + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise + + async def get_latest_market_sentiment(self) -> Optional[dict]: + """Retrieve the 'latest' market sentiment row, or None if not found.""" + stmt = select(MarketSentimentRow).where(MarketSentimentRow.id == "latest") + async with self._session_factory() as session: + try: + result = await session.execute(stmt) + row = result.scalar_one_or_none() + except Exception: + await session.rollback() + raise + if row is None: + return None + return { + "id": row.id, + "fear_greed": row.fear_greed, + "fear_greed_label": row.fear_greed_label, + "vix": row.vix, + "fed_stance": row.fed_stance, + "market_regime": row.market_regime, + "updated_at": row.updated_at, + } + + async def insert_stock_selection( + self, + trade_date: date, + symbol: str, + side: str, + conviction: float, + reason: str, + key_news: list, + sentiment_snapshot: dict, + ) -> None: + """Insert a stock selection row with JSON-encoded lists/dicts.""" + row = StockSelectionRow( + id=str(uuid.uuid4()), + trade_date=trade_date, + symbol=symbol, + side=side, + conviction=conviction, + reason=reason, + key_news=json.dumps(key_news), + sentiment_snapshot=json.dumps(sentiment_snapshot), + created_at=datetime.now(timezone.utc), + ) + async with self._session_factory() as session: + try: + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise + + async def get_stock_selections(self, trade_date: date) -> list[dict]: + """Retrieve stock selections for a given trade date.""" + stmt = ( + select(StockSelectionRow) + .where(StockSelectionRow.trade_date == trade_date) + .order_by(StockSelectionRow.conviction.desc()) + ) + async with self._session_factory() as session: + try: + result = await session.execute(stmt) + rows = result.scalars().all() + except Exception: + await session.rollback() + raise + return [ + { + "id": r.id, + "trade_date": r.trade_date, + "symbol": r.symbol, + "side": r.side, + "conviction": r.conviction, + "reason": r.reason, + "key_news": json.loads(r.key_news) if r.key_news else [], + "sentiment_snapshot": json.loads(r.sentiment_snapshot) if r.sentiment_snapshot else {}, + "created_at": r.created_at, + } + for r in rows + ] diff --git a/shared/src/shared/events.py b/shared/src/shared/events.py index 72f8865..63f93a2 100644 --- a/shared/src/shared/events.py +++ b/shared/src/shared/events.py @@ -5,13 +5,14 @@ from typing import Any from pydantic import BaseModel -from shared.models import Candle, Signal, Order +from shared.models import Candle, Signal, Order, NewsItem class EventType(str, Enum): CANDLE = "CANDLE" SIGNAL = "SIGNAL" ORDER = "ORDER" + NEWS = "NEWS" class CandleEvent(BaseModel): @@ -59,10 +60,26 @@ class OrderEvent(BaseModel): return cls(type=raw["type"], data=Order(**raw["data"])) +class NewsEvent(BaseModel): + type: EventType = EventType.NEWS + data: NewsItem + + def to_dict(self) -> dict: + return { + "type": self.type, + "data": self.data.model_dump(mode="json"), + } + + @classmethod + def from_raw(cls, raw: dict) -> "NewsEvent": + return cls(type=raw["type"], data=NewsItem(**raw["data"])) + + _EVENT_TYPE_MAP = { EventType.CANDLE: CandleEvent, EventType.SIGNAL: SignalEvent, EventType.ORDER: OrderEvent, + EventType.NEWS: NewsEvent, } |
