summaryrefslogtreecommitdiff
path: root/shared
diff options
context:
space:
mode:
Diffstat (limited to 'shared')
-rw-r--r--shared/src/shared/db.py247
-rw-r--r--shared/src/shared/events.py19
-rw-r--r--shared/tests/test_db_news.py80
-rw-r--r--shared/tests/test_news_events.py56
4 files changed, 398 insertions, 4 deletions
diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py
index 901e293..55f93b4 100644
--- a/shared/src/shared/db.py
+++ b/shared/src/shared/db.py
@@ -1,15 +1,28 @@
"""Database layer using SQLAlchemy 2.0 async ORM for the trading platform."""
+import json
+import uuid
from contextlib import asynccontextmanager
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, date, timedelta, timezone
from decimal import Decimal
from typing import Optional
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
-from shared.models import Candle, Signal, Order, OrderStatus
-from shared.sa_models import Base, CandleRow, SignalRow, OrderRow, PortfolioSnapshotRow
+from shared.models import Candle, Signal, Order, OrderStatus, NewsItem
+from shared.sentiment_models import SymbolScore, MarketSentiment
+from shared.sa_models import (
+ Base,
+ CandleRow,
+ SignalRow,
+ OrderRow,
+ PortfolioSnapshotRow,
+ NewsItemRow,
+ SymbolScoreRow,
+ MarketSentimentRow,
+ StockSelectionRow,
+)
class Database:
@@ -195,3 +208,231 @@ class Database:
}
for r in rows
]
+
+ async def insert_news_item(self, item: NewsItem) -> None:
+ """Insert a NewsItem row, JSON-encoding symbols and raw_data."""
+ row = NewsItemRow(
+ id=item.id,
+ source=item.source,
+ headline=item.headline,
+ summary=item.summary,
+ url=item.url,
+ published_at=item.published_at,
+ symbols=json.dumps(item.symbols),
+ sentiment=item.sentiment,
+ category=item.category.value,
+ raw_data=json.dumps(item.raw_data),
+ created_at=item.created_at,
+ )
+ async with self._session_factory() as session:
+ try:
+ session.add(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
+
+ async def get_recent_news(self, hours: int = 24) -> list[dict]:
+ """Retrieve news items published in the last N hours."""
+ since = datetime.now(timezone.utc) - timedelta(hours=hours)
+ stmt = (
+ select(NewsItemRow)
+ .where(NewsItemRow.published_at >= since)
+ .order_by(NewsItemRow.published_at.desc())
+ )
+ async with self._session_factory() as session:
+ try:
+ result = await session.execute(stmt)
+ rows = result.scalars().all()
+ except Exception:
+ await session.rollback()
+ raise
+ return [
+ {
+ "id": r.id,
+ "source": r.source,
+ "headline": r.headline,
+ "summary": r.summary,
+ "url": r.url,
+ "published_at": r.published_at,
+ "symbols": json.loads(r.symbols) if r.symbols else [],
+ "sentiment": r.sentiment,
+ "category": r.category,
+ "raw_data": json.loads(r.raw_data) if r.raw_data else {},
+ "created_at": r.created_at,
+ }
+ for r in rows
+ ]
+
+ async def upsert_symbol_score(self, score: SymbolScore) -> None:
+ """Insert or update a SymbolScore row, keyed by symbol."""
+ async with self._session_factory() as session:
+ try:
+ stmt = select(SymbolScoreRow).where(SymbolScoreRow.symbol == score.symbol)
+ result = await session.execute(stmt)
+ existing = result.scalar_one_or_none()
+ if existing is not None:
+ existing.news_score = score.news_score
+ existing.news_count = score.news_count
+ existing.social_score = score.social_score
+ existing.policy_score = score.policy_score
+ existing.filing_score = score.filing_score
+ existing.composite = score.composite
+ existing.updated_at = score.updated_at
+ else:
+ row = SymbolScoreRow(
+ id=str(uuid.uuid4()),
+ symbol=score.symbol,
+ news_score=score.news_score,
+ news_count=score.news_count,
+ social_score=score.social_score,
+ policy_score=score.policy_score,
+ filing_score=score.filing_score,
+ composite=score.composite,
+ updated_at=score.updated_at,
+ )
+ session.add(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
+
+ async def get_top_symbol_scores(self, limit: int = 20) -> list[dict]:
+ """Retrieve top symbol scores ordered by composite descending."""
+ stmt = (
+ select(SymbolScoreRow)
+ .order_by(SymbolScoreRow.composite.desc())
+ .limit(limit)
+ )
+ async with self._session_factory() as session:
+ try:
+ result = await session.execute(stmt)
+ rows = result.scalars().all()
+ except Exception:
+ await session.rollback()
+ raise
+ return [
+ {
+ "id": r.id,
+ "symbol": r.symbol,
+ "news_score": r.news_score,
+ "news_count": r.news_count,
+ "social_score": r.social_score,
+ "policy_score": r.policy_score,
+ "filing_score": r.filing_score,
+ "composite": r.composite,
+ "updated_at": r.updated_at,
+ }
+ for r in rows
+ ]
+
+ async def upsert_market_sentiment(self, ms: MarketSentiment) -> None:
+ """Insert or update the single 'latest' market sentiment row."""
+ async with self._session_factory() as session:
+ try:
+ stmt = select(MarketSentimentRow).where(MarketSentimentRow.id == "latest")
+ result = await session.execute(stmt)
+ existing = result.scalar_one_or_none()
+ if existing is not None:
+ existing.fear_greed = ms.fear_greed
+ existing.fear_greed_label = ms.fear_greed_label
+ existing.vix = ms.vix
+ existing.fed_stance = ms.fed_stance
+ existing.market_regime = ms.market_regime
+ existing.updated_at = ms.updated_at
+ else:
+ row = MarketSentimentRow(
+ id="latest",
+ fear_greed=ms.fear_greed,
+ fear_greed_label=ms.fear_greed_label,
+ vix=ms.vix,
+ fed_stance=ms.fed_stance,
+ market_regime=ms.market_regime,
+ updated_at=ms.updated_at,
+ )
+ session.add(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
+
+ async def get_latest_market_sentiment(self) -> Optional[dict]:
+ """Retrieve the 'latest' market sentiment row, or None if not found."""
+ stmt = select(MarketSentimentRow).where(MarketSentimentRow.id == "latest")
+ async with self._session_factory() as session:
+ try:
+ result = await session.execute(stmt)
+ row = result.scalar_one_or_none()
+ except Exception:
+ await session.rollback()
+ raise
+ if row is None:
+ return None
+ return {
+ "id": row.id,
+ "fear_greed": row.fear_greed,
+ "fear_greed_label": row.fear_greed_label,
+ "vix": row.vix,
+ "fed_stance": row.fed_stance,
+ "market_regime": row.market_regime,
+ "updated_at": row.updated_at,
+ }
+
+ async def insert_stock_selection(
+ self,
+ trade_date: date,
+ symbol: str,
+ side: str,
+ conviction: float,
+ reason: str,
+ key_news: list,
+ sentiment_snapshot: dict,
+ ) -> None:
+ """Insert a stock selection row with JSON-encoded lists/dicts."""
+ row = StockSelectionRow(
+ id=str(uuid.uuid4()),
+ trade_date=trade_date,
+ symbol=symbol,
+ side=side,
+ conviction=conviction,
+ reason=reason,
+ key_news=json.dumps(key_news),
+ sentiment_snapshot=json.dumps(sentiment_snapshot),
+ created_at=datetime.now(timezone.utc),
+ )
+ async with self._session_factory() as session:
+ try:
+ session.add(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
+
+ async def get_stock_selections(self, trade_date: date) -> list[dict]:
+ """Retrieve stock selections for a given trade date."""
+ stmt = (
+ select(StockSelectionRow)
+ .where(StockSelectionRow.trade_date == trade_date)
+ .order_by(StockSelectionRow.conviction.desc())
+ )
+ async with self._session_factory() as session:
+ try:
+ result = await session.execute(stmt)
+ rows = result.scalars().all()
+ except Exception:
+ await session.rollback()
+ raise
+ return [
+ {
+ "id": r.id,
+ "trade_date": r.trade_date,
+ "symbol": r.symbol,
+ "side": r.side,
+ "conviction": r.conviction,
+ "reason": r.reason,
+ "key_news": json.loads(r.key_news) if r.key_news else [],
+ "sentiment_snapshot": json.loads(r.sentiment_snapshot) if r.sentiment_snapshot else {},
+ "created_at": r.created_at,
+ }
+ for r in rows
+ ]
diff --git a/shared/src/shared/events.py b/shared/src/shared/events.py
index 72f8865..63f93a2 100644
--- a/shared/src/shared/events.py
+++ b/shared/src/shared/events.py
@@ -5,13 +5,14 @@ from typing import Any
from pydantic import BaseModel
-from shared.models import Candle, Signal, Order
+from shared.models import Candle, Signal, Order, NewsItem
class EventType(str, Enum):
CANDLE = "CANDLE"
SIGNAL = "SIGNAL"
ORDER = "ORDER"
+ NEWS = "NEWS"
class CandleEvent(BaseModel):
@@ -59,10 +60,26 @@ class OrderEvent(BaseModel):
return cls(type=raw["type"], data=Order(**raw["data"]))
+class NewsEvent(BaseModel):
+ type: EventType = EventType.NEWS
+ data: NewsItem
+
+ def to_dict(self) -> dict:
+ return {
+ "type": self.type,
+ "data": self.data.model_dump(mode="json"),
+ }
+
+ @classmethod
+ def from_raw(cls, raw: dict) -> "NewsEvent":
+ return cls(type=raw["type"], data=NewsItem(**raw["data"]))
+
+
_EVENT_TYPE_MAP = {
EventType.CANDLE: CandleEvent,
EventType.SIGNAL: SignalEvent,
EventType.ORDER: OrderEvent,
+ EventType.NEWS: NewsEvent,
}
diff --git a/shared/tests/test_db_news.py b/shared/tests/test_db_news.py
new file mode 100644
index 0000000..f13cf1e
--- /dev/null
+++ b/shared/tests/test_db_news.py
@@ -0,0 +1,80 @@
+"""Tests for database news/sentiment methods. Uses in-memory SQLite."""
+
+import json
+import uuid
+import pytest
+from datetime import datetime, date, timezone
+
+from shared.db import Database
+from shared.models import NewsItem, NewsCategory
+from shared.sentiment_models import SymbolScore, MarketSentiment
+
+
+@pytest.fixture
+async def db():
+ database = Database("sqlite+aiosqlite://")
+ await database.connect()
+ yield database
+ await database.close()
+
+
+async def test_insert_and_get_news_items(db):
+ item = NewsItem(
+ source="finnhub",
+ headline="AAPL earnings beat",
+ published_at=datetime(2026, 4, 2, 12, 0, tzinfo=timezone.utc),
+ sentiment=0.8,
+ category=NewsCategory.EARNINGS,
+ symbols=["AAPL"],
+ )
+ await db.insert_news_item(item)
+ items = await db.get_recent_news(hours=24)
+ assert len(items) == 1
+ assert items[0]["headline"] == "AAPL earnings beat"
+
+
+async def test_upsert_symbol_score(db):
+ score = SymbolScore(
+ symbol="AAPL",
+ news_score=0.5,
+ news_count=10,
+ social_score=0.3,
+ policy_score=0.0,
+ filing_score=0.2,
+ composite=0.3,
+ updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ )
+ await db.upsert_symbol_score(score)
+ scores = await db.get_top_symbol_scores(limit=5)
+ assert len(scores) == 1
+ assert scores[0]["symbol"] == "AAPL"
+
+
+async def test_upsert_market_sentiment(db):
+ ms = MarketSentiment(
+ fear_greed=55,
+ fear_greed_label="Neutral",
+ vix=18.2,
+ fed_stance="neutral",
+ market_regime="neutral",
+ updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ )
+ await db.upsert_market_sentiment(ms)
+ result = await db.get_latest_market_sentiment()
+ assert result is not None
+ assert result["fear_greed"] == 55
+
+
+async def test_insert_stock_selection(db):
+ await db.insert_stock_selection(
+ trade_date=date(2026, 4, 2),
+ symbol="NVDA",
+ side="BUY",
+ conviction=0.85,
+ reason="CHIPS Act",
+ key_news=["Trump signs CHIPS expansion"],
+ sentiment_snapshot={"composite": 0.8},
+ )
+ selections = await db.get_stock_selections(date(2026, 4, 2))
+ assert len(selections) == 1
+ assert selections[0]["symbol"] == "NVDA"
diff --git a/shared/tests/test_news_events.py b/shared/tests/test_news_events.py
new file mode 100644
index 0000000..384796a
--- /dev/null
+++ b/shared/tests/test_news_events.py
@@ -0,0 +1,56 @@
+"""Tests for NewsEvent."""
+
+from datetime import datetime, timezone
+
+from shared.models import NewsCategory, NewsItem
+from shared.events import NewsEvent, EventType, Event
+
+
+def test_news_event_to_dict():
+ item = NewsItem(
+ source="finnhub",
+ headline="Test",
+ published_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ sentiment=0.5,
+ category=NewsCategory.MACRO,
+ )
+ event = NewsEvent(data=item)
+ d = event.to_dict()
+ assert d["type"] == EventType.NEWS
+ assert d["data"]["source"] == "finnhub"
+
+
+def test_news_event_from_raw():
+ raw = {
+ "type": "NEWS",
+ "data": {
+ "id": "abc",
+ "source": "rss",
+ "headline": "Test headline",
+ "published_at": "2026-04-02T00:00:00+00:00",
+ "sentiment": 0.3,
+ "category": "earnings",
+ "symbols": ["AAPL"],
+ "raw_data": {},
+ },
+ }
+ event = NewsEvent.from_raw(raw)
+ assert event.data.source == "rss"
+ assert event.data.symbols == ["AAPL"]
+
+
+def test_event_dispatcher_news():
+ raw = {
+ "type": "NEWS",
+ "data": {
+ "id": "abc",
+ "source": "finnhub",
+ "headline": "Test",
+ "published_at": "2026-04-02T00:00:00+00:00",
+ "sentiment": 0.0,
+ "category": "macro",
+ "raw_data": {},
+ },
+ }
+ event = Event.from_dict(raw)
+ assert isinstance(event, NewsEvent)