diff options
Diffstat (limited to 'shared')
33 files changed, 1526 insertions, 373 deletions
diff --git a/shared/alembic/versions/001_initial_schema.py b/shared/alembic/versions/001_initial_schema.py index 2bdaafc..7b744ee 100644 --- a/shared/alembic/versions/001_initial_schema.py +++ b/shared/alembic/versions/001_initial_schema.py @@ -5,16 +5,16 @@ Revises: Create Date: 2026-04-01 """ -from typing import Sequence, Union +from collections.abc import Sequence -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "001" -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +down_revision: str | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None def upgrade() -> None: diff --git a/shared/alembic/versions/002_news_sentiment_tables.py b/shared/alembic/versions/002_news_sentiment_tables.py new file mode 100644 index 0000000..d85a634 --- /dev/null +++ b/shared/alembic/versions/002_news_sentiment_tables.py @@ -0,0 +1,84 @@ +"""Add news, sentiment, and stock selection tables + +Revision ID: 002 +Revises: 001 +Create Date: 2026-04-02 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +revision: str = "002" +down_revision: str | None = "001" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "news_items", + sa.Column("id", sa.Text, primary_key=True), + sa.Column("source", sa.Text, nullable=False), + sa.Column("headline", sa.Text, nullable=False), + sa.Column("summary", sa.Text), + sa.Column("url", sa.Text), + sa.Column("published_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("symbols", sa.Text), + sa.Column("sentiment", sa.Float, nullable=False), + sa.Column("category", sa.Text, nullable=False), + sa.Column("raw_data", sa.Text), + sa.Column( + "created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now() + ), + ) + op.create_index("idx_news_items_published", "news_items", ["published_at"]) + op.create_index("idx_news_items_source", "news_items", ["source"]) + + op.create_table( + "symbol_scores", + sa.Column("id", sa.Text, primary_key=True), + sa.Column("symbol", sa.Text, nullable=False, unique=True), + sa.Column("news_score", sa.Float, nullable=False, server_default="0"), + sa.Column("news_count", sa.Integer, nullable=False, server_default="0"), + sa.Column("social_score", sa.Float, nullable=False, server_default="0"), + sa.Column("policy_score", sa.Float, nullable=False, server_default="0"), + sa.Column("filing_score", sa.Float, nullable=False, server_default="0"), + sa.Column("composite", sa.Float, nullable=False, server_default="0"), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + ) + + op.create_table( + "market_sentiment", + sa.Column("id", sa.Text, primary_key=True), + sa.Column("fear_greed", sa.Integer, nullable=False), + sa.Column("fear_greed_label", sa.Text, nullable=False), + sa.Column("vix", sa.Float), + sa.Column("fed_stance", sa.Text, nullable=False, server_default="neutral"), + sa.Column("market_regime", sa.Text, nullable=False, server_default="neutral"), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + ) + + op.create_table( + "stock_selections", + sa.Column("id", sa.Text, primary_key=True), + sa.Column("trade_date", sa.Date, nullable=False), + sa.Column("symbol", sa.Text, nullable=False), + sa.Column("side", sa.Text, nullable=False), + sa.Column("conviction", sa.Float, nullable=False), + sa.Column("reason", sa.Text, nullable=False), + sa.Column("key_news", sa.Text), + sa.Column("sentiment_snapshot", sa.Text), + sa.Column( + "created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now() + ), + ) + op.create_index("idx_stock_selections_date", "stock_selections", ["trade_date"]) + + +def downgrade() -> None: + op.drop_table("stock_selections") + op.drop_table("market_sentiment") + op.drop_table("symbol_scores") + op.drop_table("news_items") diff --git a/shared/alembic/versions/003_add_missing_indexes.py b/shared/alembic/versions/003_add_missing_indexes.py new file mode 100644 index 0000000..7a252d4 --- /dev/null +++ b/shared/alembic/versions/003_add_missing_indexes.py @@ -0,0 +1,35 @@ +"""Add missing indexes for common query patterns. + +Revision ID: 003 +Revises: 002 +Create Date: 2026-04-02 +""" + +from collections.abc import Sequence + +from alembic import op + +revision: str = "003" +down_revision: str | None = "002" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_index("idx_signals_symbol_created", "signals", ["symbol", "created_at"]) + op.create_index( + "idx_orders_symbol_status_created", "orders", ["symbol", "status", "created_at"] + ) + op.create_index("idx_trades_order_id", "trades", ["order_id"]) + op.create_index("idx_trades_symbol_traded", "trades", ["symbol", "traded_at"]) + op.create_index("idx_portfolio_snapshots_at", "portfolio_snapshots", ["snapshot_at"]) + op.create_index("idx_symbol_scores_symbol", "symbol_scores", ["symbol"], unique=True) + + +def downgrade() -> None: + op.drop_index("idx_symbol_scores_symbol", table_name="symbol_scores") + op.drop_index("idx_portfolio_snapshots_at", table_name="portfolio_snapshots") + op.drop_index("idx_trades_symbol_traded", table_name="trades") + op.drop_index("idx_trades_order_id", table_name="trades") + op.drop_index("idx_orders_symbol_status_created", table_name="orders") + op.drop_index("idx_signals_symbol_created", table_name="signals") diff --git a/shared/alembic/versions/004_add_signal_detail_columns.py b/shared/alembic/versions/004_add_signal_detail_columns.py new file mode 100644 index 0000000..4009b6e --- /dev/null +++ b/shared/alembic/versions/004_add_signal_detail_columns.py @@ -0,0 +1,25 @@ +"""Add conviction, stop_loss, take_profit columns to signals table. + +Revision ID: 004 +Revises: 003 +""" + +import sqlalchemy as sa +from alembic import op + +revision = "004" +down_revision = "003" + + +def upgrade(): + op.add_column( + "signals", sa.Column("conviction", sa.Float, nullable=False, server_default="1.0") + ) + op.add_column("signals", sa.Column("stop_loss", sa.Numeric, nullable=True)) + op.add_column("signals", sa.Column("take_profit", sa.Numeric, nullable=True)) + + +def downgrade(): + op.drop_column("signals", "take_profit") + op.drop_column("signals", "stop_loss") + op.drop_column("signals", "conviction") diff --git a/shared/pyproject.toml b/shared/pyproject.toml index 830088d..eb74a11 100644 --- a/shared/pyproject.toml +++ b/shared/pyproject.toml @@ -4,28 +4,22 @@ version = "0.1.0" description = "Shared models, events, and utilities for trading platform" requires-python = ">=3.12" dependencies = [ - "pydantic>=2.0", - "pydantic-settings>=2.0", - "redis>=5.0", - "asyncpg>=0.29", - "sqlalchemy[asyncio]>=2.0", - "alembic>=1.13", - "structlog>=24.0", - "prometheus-client>=0.20", - "pyyaml>=6.0", - "aiohttp>=3.9", - "rich>=13.0", + "pydantic>=2.8,<3", + "pydantic-settings>=2.0,<3", + "redis>=5.0,<6", + "asyncpg>=0.29,<1", + "sqlalchemy[asyncio]>=2.0,<3", + "alembic>=1.13,<2", + "structlog>=24.0,<25", + "prometheus-client>=0.20,<1", + "pyyaml>=6.0,<7", + "aiohttp>=3.9,<4", + "rich>=13.0,<14", ] [project.optional-dependencies] -dev = [ - "pytest>=8.0", - "pytest-asyncio>=0.23", - "ruff>=0.4", -] -claude = [ - "anthropic>=0.40", -] +dev = ["pytest>=8.0,<9", "pytest-asyncio>=0.23,<1", "ruff>=0.4,<1"] +claude = ["anthropic>=0.40,<1"] [build-system] requires = ["hatchling"] diff --git a/shared/src/shared/broker.py b/shared/src/shared/broker.py index fbe4576..2b96714 100644 --- a/shared/src/shared/broker.py +++ b/shared/src/shared/broker.py @@ -5,13 +5,21 @@ from typing import Any import redis.asyncio +from shared.resilience import retry_async + class RedisBroker: """Async Redis Streams broker for publishing and reading events.""" def __init__(self, redis_url: str) -> None: - self._redis = redis.asyncio.from_url(redis_url) + self._redis = redis.asyncio.from_url( + redis_url, + socket_keepalive=True, + health_check_interval=30, + retry_on_timeout=True, + ) + @retry_async(max_retries=3, base_delay=0.5, exclude=(ValueError,)) async def publish(self, stream: str, data: dict[str, Any]) -> None: """Publish a message to a Redis stream.""" payload = json.dumps(data) @@ -25,6 +33,7 @@ class RedisBroker: if "BUSYGROUP" not in str(e): raise + @retry_async(max_retries=3, base_delay=0.5, exclude=(ValueError,)) async def read_group( self, stream: str, @@ -99,6 +108,7 @@ class RedisBroker: messages.append(json.loads(payload)) return messages + @retry_async(max_retries=2, base_delay=0.5) async def ping(self) -> bool: """Ping the Redis server; return True if reachable.""" return await self._redis.ping() diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py index 4e8e7f1..0f1c66e 100644 --- a/shared/src/shared/config.py +++ b/shared/src/shared/config.py @@ -1,14 +1,18 @@ """Shared configuration settings for the trading platform.""" +from pydantic import SecretStr, field_validator from pydantic_settings import BaseSettings class Settings(BaseSettings): - alpaca_api_key: str = "" - alpaca_api_secret: str = "" + alpaca_api_key: SecretStr = SecretStr("") + alpaca_api_secret: SecretStr = SecretStr("") alpaca_paper: bool = True # Use paper trading by default - redis_url: str = "redis://localhost:6379" - database_url: str = "postgresql://trading:trading@localhost:5432/trading" + redis_url: SecretStr = SecretStr("redis://localhost:6379") + database_url: SecretStr = SecretStr("postgresql://trading:trading@localhost:5432/trading") + db_pool_size: int = 20 + db_max_overflow: int = 10 + db_pool_recycle: int = 3600 log_level: str = "INFO" risk_max_position_size: float = 0.1 risk_stop_loss_pct: float = 5.0 @@ -27,12 +31,45 @@ class Settings(BaseSettings): risk_max_consecutive_losses: int = 5 risk_loss_pause_minutes: int = 60 dry_run: bool = True - telegram_bot_token: str = "" + telegram_bot_token: SecretStr = SecretStr("") telegram_chat_id: str = "" telegram_enabled: bool = False log_format: str = "json" health_port: int = 8080 - circuit_breaker_threshold: int = 5 - circuit_breaker_timeout: int = 60 metrics_auth_token: str = "" # If set, /health and /metrics require Bearer token + # API security + api_auth_token: SecretStr = SecretStr("") + cors_origins: str = "http://localhost:3000" + # News collector + finnhub_api_key: SecretStr = SecretStr("") + news_poll_interval: int = 300 + sentiment_aggregate_interval: int = 900 + # Stock selector + selector_final_time: str = "15:30" + selector_max_picks: int = 3 + # LLM + anthropic_api_key: SecretStr = SecretStr("") + anthropic_model: str = "claude-sonnet-4-20250514" model_config = {"env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"} + + @field_validator("risk_max_position_size") + @classmethod + def validate_position_size(cls, v: float) -> float: + if v <= 0 or v > 1: + raise ValueError("risk_max_position_size must be between 0 and 1 (exclusive)") + return v + + @field_validator("health_port") + @classmethod + def validate_health_port(cls, v: int) -> int: + if v < 1024 or v > 65535: + raise ValueError("health_port must be between 1024 and 65535") + return v + + @field_validator("log_level") + @classmethod + def validate_log_level(cls, v: str) -> str: + valid = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} + if v.upper() not in valid: + raise ValueError(f"log_level must be one of {valid}") + return v.upper() diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py index 901e293..8fee000 100644 --- a/shared/src/shared/db.py +++ b/shared/src/shared/db.py @@ -1,15 +1,27 @@ """Database layer using SQLAlchemy 2.0 async ORM for the trading platform.""" +import json +import uuid from contextlib import asynccontextmanager -from datetime import datetime, timedelta, timezone +from datetime import UTC, date, datetime, timedelta from decimal import Decimal -from typing import Optional from sqlalchemy import select, update -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from shared.models import Candle, Signal, Order, OrderStatus -from shared.sa_models import Base, CandleRow, SignalRow, OrderRow, PortfolioSnapshotRow +from shared.models import Candle, NewsItem, Order, OrderStatus, Signal +from shared.sa_models import ( + Base, + CandleRow, + MarketSentimentRow, + NewsItemRow, + OrderRow, + PortfolioSnapshotRow, + SignalRow, + StockSelectionRow, + SymbolScoreRow, +) +from shared.sentiment_models import MarketSentiment, SymbolScore class Database: @@ -23,9 +35,24 @@ class Database: self._engine = None self._session_factory = None - async def connect(self) -> None: + async def connect( + self, + pool_size: int = 20, + max_overflow: int = 10, + pool_recycle: int = 3600, + ) -> None: """Create the async engine, session factory, and all tables.""" - self._engine = create_async_engine(self._database_url) + if self._database_url.startswith("sqlite"): + # SQLite doesn't support pooling options + self._engine = create_async_engine(self._database_url) + else: + self._engine = create_async_engine( + self._database_url, + pool_pre_ping=True, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + ) self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False) async with self._engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) @@ -85,6 +112,9 @@ class Database: price=signal.price, quantity=signal.quantity, reason=signal.reason, + conviction=signal.conviction, + stop_loss=signal.stop_loss, + take_profit=signal.take_profit, created_at=signal.created_at, ) async with self._session_factory() as session: @@ -121,7 +151,7 @@ class Database: self, order_id: str, status: OrderStatus, - filled_at: Optional[datetime] = None, + filled_at: datetime | None = None, ) -> None: """Update the status (and optionally filled_at) of an order.""" stmt = ( @@ -167,7 +197,7 @@ class Database: total_value=total_value, realized_pnl=realized_pnl, unrealized_pnl=unrealized_pnl, - snapshot_at=datetime.now(timezone.utc), + snapshot_at=datetime.now(UTC), ) session.add(row) await session.commit() @@ -178,7 +208,7 @@ class Database: async def get_portfolio_snapshots(self, days: int = 30) -> list[dict]: """Retrieve recent portfolio snapshots.""" async with self.get_session() as session: - since = datetime.now(timezone.utc) - timedelta(days=days) + since = datetime.now(UTC) - timedelta(days=days) stmt = ( select(PortfolioSnapshotRow) .where(PortfolioSnapshotRow.snapshot_at >= since) @@ -195,3 +225,229 @@ class Database: } for r in rows ] + + async def insert_news_item(self, item: NewsItem) -> None: + """Insert a NewsItem row, JSON-encoding symbols and raw_data.""" + row = NewsItemRow( + id=item.id, + source=item.source, + headline=item.headline, + summary=item.summary, + url=item.url, + published_at=item.published_at, + symbols=json.dumps(item.symbols), + sentiment=item.sentiment, + category=item.category.value, + raw_data=json.dumps(item.raw_data), + created_at=item.created_at, + ) + async with self._session_factory() as session: + try: + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise + + async def get_recent_news(self, hours: int = 24) -> list[dict]: + """Retrieve news items published in the last N hours.""" + since = datetime.now(UTC) - timedelta(hours=hours) + stmt = ( + select(NewsItemRow) + .where(NewsItemRow.published_at >= since) + .order_by(NewsItemRow.published_at.desc()) + ) + async with self._session_factory() as session: + try: + result = await session.execute(stmt) + rows = result.scalars().all() + except Exception: + await session.rollback() + raise + return [ + { + "id": r.id, + "source": r.source, + "headline": r.headline, + "summary": r.summary, + "url": r.url, + "published_at": r.published_at, + "symbols": json.loads(r.symbols) if r.symbols else [], + "sentiment": r.sentiment, + "category": r.category, + "raw_data": json.loads(r.raw_data) if r.raw_data else {}, + "created_at": r.created_at, + } + for r in rows + ] + + async def upsert_symbol_score(self, score: SymbolScore) -> None: + """Insert or update a SymbolScore row, keyed by symbol.""" + async with self._session_factory() as session: + try: + stmt = select(SymbolScoreRow).where(SymbolScoreRow.symbol == score.symbol) + result = await session.execute(stmt) + existing = result.scalar_one_or_none() + if existing is not None: + existing.news_score = score.news_score + existing.news_count = score.news_count + existing.social_score = score.social_score + existing.policy_score = score.policy_score + existing.filing_score = score.filing_score + existing.composite = score.composite + existing.updated_at = score.updated_at + else: + row = SymbolScoreRow( + id=str(uuid.uuid4()), + symbol=score.symbol, + news_score=score.news_score, + news_count=score.news_count, + social_score=score.social_score, + policy_score=score.policy_score, + filing_score=score.filing_score, + composite=score.composite, + updated_at=score.updated_at, + ) + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise + + async def get_top_symbol_scores(self, limit: int = 20) -> list[dict]: + """Retrieve top symbol scores ordered by composite descending.""" + stmt = select(SymbolScoreRow).order_by(SymbolScoreRow.composite.desc()).limit(limit) + async with self._session_factory() as session: + try: + result = await session.execute(stmt) + rows = result.scalars().all() + except Exception: + await session.rollback() + raise + return [ + { + "id": r.id, + "symbol": r.symbol, + "news_score": r.news_score, + "news_count": r.news_count, + "social_score": r.social_score, + "policy_score": r.policy_score, + "filing_score": r.filing_score, + "composite": r.composite, + "updated_at": r.updated_at, + } + for r in rows + ] + + async def upsert_market_sentiment(self, ms: MarketSentiment) -> None: + """Insert or update the single 'latest' market sentiment row.""" + async with self._session_factory() as session: + try: + stmt = select(MarketSentimentRow).where(MarketSentimentRow.id == "latest") + result = await session.execute(stmt) + existing = result.scalar_one_or_none() + if existing is not None: + existing.fear_greed = ms.fear_greed + existing.fear_greed_label = ms.fear_greed_label + existing.vix = ms.vix + existing.fed_stance = ms.fed_stance + existing.market_regime = ms.market_regime + existing.updated_at = ms.updated_at + else: + row = MarketSentimentRow( + id="latest", + fear_greed=ms.fear_greed, + fear_greed_label=ms.fear_greed_label, + vix=ms.vix, + fed_stance=ms.fed_stance, + market_regime=ms.market_regime, + updated_at=ms.updated_at, + ) + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise + + async def get_latest_market_sentiment(self) -> dict | None: + """Retrieve the 'latest' market sentiment row, or None if not found.""" + stmt = select(MarketSentimentRow).where(MarketSentimentRow.id == "latest") + async with self._session_factory() as session: + try: + result = await session.execute(stmt) + row = result.scalar_one_or_none() + except Exception: + await session.rollback() + raise + if row is None: + return None + return { + "id": row.id, + "fear_greed": row.fear_greed, + "fear_greed_label": row.fear_greed_label, + "vix": row.vix, + "fed_stance": row.fed_stance, + "market_regime": row.market_regime, + "updated_at": row.updated_at, + } + + async def insert_stock_selection( + self, + trade_date: date, + symbol: str, + side: str, + conviction: float, + reason: str, + key_news: list, + sentiment_snapshot: dict, + ) -> None: + """Insert a stock selection row with JSON-encoded lists/dicts.""" + row = StockSelectionRow( + id=str(uuid.uuid4()), + trade_date=trade_date, + symbol=symbol, + side=side, + conviction=conviction, + reason=reason, + key_news=json.dumps(key_news), + sentiment_snapshot=json.dumps(sentiment_snapshot), + created_at=datetime.now(UTC), + ) + async with self._session_factory() as session: + try: + session.add(row) + await session.commit() + except Exception: + await session.rollback() + raise + + async def get_stock_selections(self, trade_date: date) -> list[dict]: + """Retrieve stock selections for a given trade date.""" + stmt = ( + select(StockSelectionRow) + .where(StockSelectionRow.trade_date == trade_date) + .order_by(StockSelectionRow.conviction.desc()) + ) + async with self._session_factory() as session: + try: + result = await session.execute(stmt) + rows = result.scalars().all() + except Exception: + await session.rollback() + raise + return [ + { + "id": r.id, + "trade_date": r.trade_date, + "symbol": r.symbol, + "side": r.side, + "conviction": r.conviction, + "reason": r.reason, + "key_news": json.loads(r.key_news) if r.key_news else [], + "sentiment_snapshot": json.loads(r.sentiment_snapshot) + if r.sentiment_snapshot + else {}, + "created_at": r.created_at, + } + for r in rows + ] diff --git a/shared/src/shared/events.py b/shared/src/shared/events.py index 72f8865..37217a0 100644 --- a/shared/src/shared/events.py +++ b/shared/src/shared/events.py @@ -1,17 +1,18 @@ """Event types and serialization for the trading platform.""" -from enum import Enum +from enum import StrEnum from typing import Any from pydantic import BaseModel -from shared.models import Candle, Signal, Order +from shared.models import Candle, NewsItem, Order, Signal -class EventType(str, Enum): +class EventType(StrEnum): CANDLE = "CANDLE" SIGNAL = "SIGNAL" ORDER = "ORDER" + NEWS = "NEWS" class CandleEvent(BaseModel): @@ -59,10 +60,26 @@ class OrderEvent(BaseModel): return cls(type=raw["type"], data=Order(**raw["data"])) +class NewsEvent(BaseModel): + type: EventType = EventType.NEWS + data: NewsItem + + def to_dict(self) -> dict: + return { + "type": self.type, + "data": self.data.model_dump(mode="json"), + } + + @classmethod + def from_raw(cls, raw: dict) -> "NewsEvent": + return cls(type=raw["type"], data=NewsItem(**raw["data"])) + + _EVENT_TYPE_MAP = { EventType.CANDLE: CandleEvent, EventType.SIGNAL: SignalEvent, EventType.ORDER: OrderEvent, + EventType.NEWS: NewsEvent, } @@ -71,6 +88,16 @@ class Event: @staticmethod def from_dict(data: dict) -> Any: - event_type = EventType(data["type"]) + """Deserialize a raw dict into the appropriate event type. + + Raises ValueError for malformed or unrecognized event data. + """ + try: + event_type = EventType(data["type"]) + except (KeyError, ValueError) as exc: + raise ValueError(f"Invalid or missing event type in data: {data!r}") from exc cls = _EVENT_TYPE_MAP[event_type] - return cls.from_raw(data) + try: + return cls.from_raw(data) + except KeyError as exc: + raise ValueError(f"Missing required field in {event_type} event data: {exc}") from exc diff --git a/shared/src/shared/healthcheck.py b/shared/src/shared/healthcheck.py index 7411e8a..a19705b 100644 --- a/shared/src/shared/healthcheck.py +++ b/shared/src/shared/healthcheck.py @@ -3,10 +3,11 @@ from __future__ import annotations import time -from typing import Any, Callable, Awaitable +from collections.abc import Awaitable, Callable +from typing import Any from aiohttp import web -from prometheus_client import CollectorRegistry, REGISTRY, generate_latest, CONTENT_TYPE_LATEST +from prometheus_client import CONTENT_TYPE_LATEST, REGISTRY, CollectorRegistry, generate_latest class HealthCheckServer: diff --git a/shared/src/shared/metrics.py b/shared/src/shared/metrics.py index cd239f3..6189143 100644 --- a/shared/src/shared/metrics.py +++ b/shared/src/shared/metrics.py @@ -2,7 +2,7 @@ from __future__ import annotations -from prometheus_client import Counter, Gauge, Histogram, CollectorRegistry, REGISTRY +from prometheus_client import REGISTRY, CollectorRegistry, Counter, Gauge, Histogram class ServiceMetrics: diff --git a/shared/src/shared/models.py b/shared/src/shared/models.py index 70820b5..f357f9f 100644 --- a/shared/src/shared/models.py +++ b/shared/src/shared/models.py @@ -1,25 +1,24 @@ """Shared Pydantic models for the trading platform.""" import uuid +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone -from enum import Enum -from typing import Optional +from enum import StrEnum from pydantic import BaseModel, Field, computed_field -class OrderSide(str, Enum): +class OrderSide(StrEnum): BUY = "BUY" SELL = "SELL" -class OrderType(str, Enum): +class OrderType(StrEnum): MARKET = "MARKET" LIMIT = "LIMIT" -class OrderStatus(str, Enum): +class OrderStatus(StrEnum): PENDING = "PENDING" FILLED = "FILLED" CANCELLED = "CANCELLED" @@ -46,9 +45,9 @@ class Signal(BaseModel): quantity: Decimal reason: str conviction: float = 1.0 # 0.0 to 1.0, signal strength/confidence - stop_loss: Optional[Decimal] = None # Price to exit at loss - take_profit: Optional[Decimal] = None # Price to exit at profit - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + stop_loss: Decimal | None = None # Price to exit at loss + take_profit: Decimal | None = None # Price to exit at profit + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) class Order(BaseModel): @@ -60,8 +59,8 @@ class Order(BaseModel): price: Decimal quantity: Decimal status: OrderStatus = OrderStatus.PENDING - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) - filled_at: Optional[datetime] = None + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + filled_at: datetime | None = None class Position(BaseModel): @@ -74,3 +73,26 @@ class Position(BaseModel): @property def unrealized_pnl(self) -> Decimal: return self.quantity * (self.current_price - self.avg_entry_price) + + +class NewsCategory(StrEnum): + POLICY = "policy" + EARNINGS = "earnings" + MACRO = "macro" + SOCIAL = "social" + FILING = "filing" + FED = "fed" + + +class NewsItem(BaseModel): + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + source: str + headline: str + summary: str | None = None + url: str | None = None + published_at: datetime + symbols: list[str] = [] + sentiment: float + category: NewsCategory + raw_data: dict = {} + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) diff --git a/shared/src/shared/notifier.py b/shared/src/shared/notifier.py index f03919c..cfc86cd 100644 --- a/shared/src/shared/notifier.py +++ b/shared/src/shared/notifier.py @@ -2,12 +2,13 @@ import asyncio import logging +from collections.abc import Sequence from decimal import Decimal -from typing import Optional, Sequence import aiohttp -from shared.models import Signal, Order, Position +from shared.models import Order, Position, Signal +from shared.sentiment_models import MarketSentiment, SelectedStock logger = logging.getLogger(__name__) @@ -22,7 +23,7 @@ class TelegramNotifier: self._bot_token = bot_token self._chat_id = chat_id self._semaphore = asyncio.Semaphore(1) - self._session: Optional[aiohttp.ClientSession] = None + self._session: aiohttp.ClientSession | None = None @property def enabled(self) -> bool: @@ -112,17 +113,45 @@ class TelegramNotifier: "", "<b>Positions:</b>", ] - for pos in positions: - lines.append( - f" {pos.symbol}: qty={pos.quantity} " - f"entry={pos.avg_entry_price} " - f"current={pos.current_price} " - f"pnl={pos.unrealized_pnl}" - ) + lines.extend( + f" {pos.symbol}: qty={pos.quantity} " + f"entry={pos.avg_entry_price} " + f"current={pos.current_price} " + f"pnl={pos.unrealized_pnl}" + for pos in positions + ) if not positions: lines.append(" No open positions") await self.send("\n".join(lines)) + async def send_stock_selection( + self, + selections: list[SelectedStock], + market: MarketSentiment | None = None, + ) -> None: + """Format and send stock selection notification.""" + lines = [f"<b>📊 Stock Selection ({len(selections)} picks)</b>", ""] + + side_emoji = {"BUY": "🟢", "SELL": "🔴"} + + for i, s in enumerate(selections, 1): + emoji = side_emoji.get(s.side.value, "⚪") + lines.append( + f"{i}. <b>{s.symbol}</b> {emoji} {s.side.value} (conviction: {s.conviction:.0%})" + ) + lines.append(f" {s.reason}") + if s.key_news: + lines.append(f" News: {s.key_news[0]}") + lines.append("") + + if market: + lines.append( + f"Market: F&G {market.fear_greed} ({market.fear_greed_label})" + + (f" | VIX {market.vix:.1f}" if market.vix else "") + ) + + await self.send("\n".join(lines)) + async def close(self) -> None: """Close the underlying aiohttp session.""" if self._session is not None: diff --git a/shared/src/shared/resilience.py b/shared/src/shared/resilience.py index e43fd21..66225d7 100644 --- a/shared/src/shared/resilience.py +++ b/shared/src/shared/resilience.py @@ -1,29 +1,45 @@ -"""Retry with exponential backoff and circuit breaker utilities.""" +"""Resilience utilities for the trading platform. + +Provides retry, circuit breaker, and timeout primitives using only stdlib. +No external dependencies required. +""" from __future__ import annotations import asyncio -import enum import functools import logging import random import time -from typing import Any, Callable +from collections.abc import Callable +from contextlib import asynccontextmanager +from enum import StrEnum +from typing import Any -logger = logging.getLogger(__name__) + +class _State(StrEnum): + CLOSED = "closed" + OPEN = "open" + HALF_OPEN = "half_open" -# --------------------------------------------------------------------------- -# retry_with_backoff -# --------------------------------------------------------------------------- +logger = logging.getLogger(__name__) -def retry_with_backoff( +def retry_async( max_retries: int = 3, base_delay: float = 1.0, - max_delay: float = 60.0, + max_delay: float = 30.0, + exclude: tuple[type[BaseException], ...] = (), ) -> Callable: - """Decorator that retries an async function with exponential backoff + jitter.""" + """Decorator: exponential backoff + jitter for async functions. + + Parameters: + max_retries: Maximum number of retry attempts (after the initial call). + base_delay: Base delay in seconds for exponential backoff. + max_delay: Maximum delay cap in seconds. + exclude: Exception types that should NOT be retried (raised immediately). + """ def decorator(func: Callable) -> Callable: @functools.wraps(func) @@ -33,20 +49,21 @@ def retry_with_backoff( try: return await func(*args, **kwargs) except Exception as exc: + if exclude and isinstance(exc, exclude): + raise last_exc = exc if attempt < max_retries: delay = min(base_delay * (2**attempt), max_delay) - jitter = delay * random.uniform(0, 0.5) - total_delay = delay + jitter + jitter_delay = delay * random.uniform(0.5, 1.0) logger.warning( - "Retry %d/%d for %s after error: %s (delay=%.3fs)", + "Retry %d/%d for %s in %.2fs: %s", attempt + 1, max_retries, func.__name__, + jitter_delay, exc, - total_delay, ) - await asyncio.sleep(total_delay) + await asyncio.sleep(jitter_delay) raise last_exc # type: ignore[misc] return wrapper @@ -54,52 +71,65 @@ def retry_with_backoff( return decorator -# --------------------------------------------------------------------------- -# CircuitBreaker -# --------------------------------------------------------------------------- - - -class CircuitState(enum.Enum): - CLOSED = "closed" - OPEN = "open" - HALF_OPEN = "half_open" +class CircuitBreaker: + """Circuit breaker: opens after N consecutive failures, auto-recovers. + States: closed -> open -> half_open -> closed -class CircuitBreaker: - """Simple circuit breaker implementation.""" + Parameters: + failure_threshold: Number of consecutive failures before opening. + cooldown: Seconds to wait before allowing a half-open probe. + """ - def __init__( - self, - failure_threshold: int = 5, - recovery_timeout: float = 60.0, - ) -> None: + def __init__(self, failure_threshold: int = 5, cooldown: float = 60.0) -> None: self._failure_threshold = failure_threshold - self._recovery_timeout = recovery_timeout - self._failure_count: int = 0 - self._state = CircuitState.CLOSED + self._cooldown = cooldown + self._failures = 0 + self._state = _State.CLOSED self._opened_at: float = 0.0 - @property - def state(self) -> CircuitState: - return self._state - - def allow_request(self) -> bool: - if self._state == CircuitState.CLOSED: - return True - if self._state == CircuitState.OPEN: - if time.monotonic() - self._opened_at >= self._recovery_timeout: - self._state = CircuitState.HALF_OPEN - return True - return False - # HALF_OPEN - return True - - def record_success(self) -> None: - self._failure_count = 0 - self._state = CircuitState.CLOSED - - def record_failure(self) -> None: - self._failure_count += 1 - if self._failure_count >= self._failure_threshold: - self._state = CircuitState.OPEN - self._opened_at = time.monotonic() + async def call(self, func: Callable, *args: Any, **kwargs: Any) -> Any: + """Execute func through the breaker.""" + if self._state == _State.OPEN: + if time.monotonic() - self._opened_at >= self._cooldown: + self._state = _State.HALF_OPEN + else: + raise RuntimeError("Circuit breaker is open") + + try: + result = await func(*args, **kwargs) + except Exception: + self._failures += 1 + if self._state == _State.HALF_OPEN: + self._state = _State.OPEN + self._opened_at = time.monotonic() + logger.error( + "Circuit breaker re-opened after half-open probe failure (threshold=%d)", + self._failure_threshold, + ) + elif self._failures >= self._failure_threshold: + self._state = _State.OPEN + self._opened_at = time.monotonic() + logger.error( + "Circuit breaker opened after %d consecutive failures", + self._failures, + ) + raise + + # Success: reset + self._failures = 0 + self._state = _State.CLOSED + return result + + +@asynccontextmanager +async def async_timeout(seconds: float): + """Async context manager wrapping asyncio.timeout(). + + Raises TimeoutError with a descriptive message on timeout. + """ + try: + async with asyncio.timeout(seconds): + yield + except TimeoutError: + raise TimeoutError(f"Operation timed out after {seconds}s") from None diff --git a/shared/src/shared/sa_models.py b/shared/src/shared/sa_models.py index 8386ba8..b70a6c4 100644 --- a/shared/src/shared/sa_models.py +++ b/shared/src/shared/sa_models.py @@ -3,6 +3,7 @@ from datetime import datetime from decimal import Decimal +import sqlalchemy as sa from sqlalchemy import DateTime, ForeignKey, Numeric, Text from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column @@ -34,6 +35,9 @@ class SignalRow(Base): price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) reason: Mapped[str | None] = mapped_column(Text) + conviction: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="1.0") + stop_loss: Mapped[Decimal | None] = mapped_column(Numeric) + take_profit: Mapped[Decimal | None] = mapped_column(Numeric) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) @@ -83,3 +87,63 @@ class PortfolioSnapshotRow(Base): realized_pnl: Mapped[Decimal] = mapped_column(Numeric, nullable=False) unrealized_pnl: Mapped[Decimal] = mapped_column(Numeric, nullable=False) snapshot_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + + +class NewsItemRow(Base): + __tablename__ = "news_items" + + id: Mapped[str] = mapped_column(Text, primary_key=True) + source: Mapped[str] = mapped_column(Text, nullable=False) + headline: Mapped[str] = mapped_column(Text, nullable=False) + summary: Mapped[str | None] = mapped_column(Text) + url: Mapped[str | None] = mapped_column(Text) + published_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + symbols: Mapped[str | None] = mapped_column(Text) # JSON-encoded list + sentiment: Mapped[float] = mapped_column(sa.Float, nullable=False) + category: Mapped[str] = mapped_column(Text, nullable=False) + raw_data: Mapped[str | None] = mapped_column(Text) # JSON string + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=sa.func.now() + ) + + +class SymbolScoreRow(Base): + __tablename__ = "symbol_scores" + + id: Mapped[str] = mapped_column(Text, primary_key=True) + symbol: Mapped[str] = mapped_column(Text, nullable=False, unique=True) + news_score: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="0") + news_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="0") + social_score: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="0") + policy_score: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="0") + filing_score: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="0") + composite: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="0") + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + + +class MarketSentimentRow(Base): + __tablename__ = "market_sentiment" + + id: Mapped[str] = mapped_column(Text, primary_key=True) + fear_greed: Mapped[int] = mapped_column(sa.Integer, nullable=False) + fear_greed_label: Mapped[str] = mapped_column(Text, nullable=False) + vix: Mapped[float | None] = mapped_column(sa.Float) + fed_stance: Mapped[str] = mapped_column(Text, nullable=False, server_default="neutral") + market_regime: Mapped[str] = mapped_column(Text, nullable=False, server_default="neutral") + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + + +class StockSelectionRow(Base): + __tablename__ = "stock_selections" + + id: Mapped[str] = mapped_column(Text, primary_key=True) + trade_date: Mapped[datetime] = mapped_column(sa.Date, nullable=False) + symbol: Mapped[str] = mapped_column(Text, nullable=False) + side: Mapped[str] = mapped_column(Text, nullable=False) + conviction: Mapped[float] = mapped_column(sa.Float, nullable=False) + reason: Mapped[str] = mapped_column(Text, nullable=False) + key_news: Mapped[str | None] = mapped_column(Text) # JSON string + sentiment_snapshot: Mapped[str | None] = mapped_column(Text) # JSON string + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=sa.func.now() + ) diff --git a/shared/src/shared/sentiment.py b/shared/src/shared/sentiment.py index 8213b47..c56da3e 100644 --- a/shared/src/shared/sentiment.py +++ b/shared/src/shared/sentiment.py @@ -1,35 +1,106 @@ -"""Market sentiment data.""" - -import logging -from dataclasses import dataclass, field -from datetime import datetime, timezone - -logger = logging.getLogger(__name__) - - -@dataclass -class SentimentData: - """Aggregated sentiment snapshot.""" - - fear_greed_value: int | None = None - fear_greed_label: str | None = None - news_sentiment: float | None = None - news_count: int = 0 - exchange_netflow: float | None = None - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - - @property - def should_buy(self) -> bool: - if self.fear_greed_value is not None and self.fear_greed_value > 70: - return False - if self.news_sentiment is not None and self.news_sentiment < -0.3: - return False - return True - - @property - def should_block(self) -> bool: - if self.fear_greed_value is not None and self.fear_greed_value > 80: - return True - if self.news_sentiment is not None and self.news_sentiment < -0.5: - return True - return False +"""Market sentiment aggregation.""" + +from datetime import datetime +from typing import ClassVar + +from shared.sentiment_models import SymbolScore + + +def _safe_avg(values: list[float]) -> float: + if not values: + return 0.0 + return sum(values) / len(values) + + +class SentimentAggregator: + """Aggregates per-news sentiment into per-symbol scores.""" + + WEIGHTS: ClassVar[dict[str, float]] = {"news": 0.3, "social": 0.2, "policy": 0.3, "filing": 0.2} + + CATEGORY_MAP: ClassVar[dict[str, str]] = { + "earnings": "news", + "macro": "news", + "social": "social", + "policy": "policy", + "filing": "filing", + "fed": "policy", + } + + def _freshness_decay(self, published_at: datetime, now: datetime) -> float: + age = now - published_at + hours = age.total_seconds() / 3600 + if hours < 1: + return 1.0 + if hours < 6: + return 0.7 + if hours < 24: + return 0.3 + return 0.0 + + def _compute_composite( + self, + news_score: float, + social_score: float, + policy_score: float, + filing_score: float, + ) -> float: + return ( + news_score * self.WEIGHTS["news"] + + social_score * self.WEIGHTS["social"] + + policy_score * self.WEIGHTS["policy"] + + filing_score * self.WEIGHTS["filing"] + ) + + def aggregate(self, news_items: list[dict], now: datetime) -> dict[str, SymbolScore]: + """Aggregate news items into per-symbol scores. + + Each dict needs: symbols, sentiment, category, published_at. + """ + symbol_data: dict[str, dict] = {} + + for item in news_items: + decay = self._freshness_decay(item["published_at"], now) + if decay == 0.0: + continue + category = item.get("category", "macro") + score_field = self.CATEGORY_MAP.get(category, "news") + weighted_sentiment = item["sentiment"] * decay + + for symbol in item.get("symbols", []): + if symbol not in symbol_data: + symbol_data[symbol] = { + "news_scores": [], + "social_scores": [], + "policy_scores": [], + "filing_scores": [], + "count": 0, + } + symbol_data[symbol][f"{score_field}_scores"].append(weighted_sentiment) + symbol_data[symbol]["count"] += 1 + + result = {} + for symbol, data in symbol_data.items(): + ns = _safe_avg(data["news_scores"]) + ss = _safe_avg(data["social_scores"]) + ps = _safe_avg(data["policy_scores"]) + fs = _safe_avg(data["filing_scores"]) + result[symbol] = SymbolScore( + symbol=symbol, + news_score=ns, + news_count=data["count"], + social_score=ss, + policy_score=ps, + filing_score=fs, + composite=self._compute_composite(ns, ss, ps, fs), + updated_at=now, + ) + return result + + def determine_regime(self, fear_greed: int, vix: float | None) -> str: + if fear_greed <= 20: + return "risk_off" + if vix is not None and vix > 30: + return "risk_off" + if fear_greed >= 60 and (vix is None or vix < 20): + return "risk_on" + return "neutral" diff --git a/shared/src/shared/sentiment_models.py b/shared/src/shared/sentiment_models.py new file mode 100644 index 0000000..ac06c20 --- /dev/null +++ b/shared/src/shared/sentiment_models.py @@ -0,0 +1,43 @@ +"""Sentiment scoring and stock selection models.""" + +from datetime import datetime + +from pydantic import BaseModel + +from shared.models import OrderSide + + +class SymbolScore(BaseModel): + symbol: str + news_score: float + news_count: int + social_score: float + policy_score: float + filing_score: float + composite: float + updated_at: datetime + + +class MarketSentiment(BaseModel): + fear_greed: int + fear_greed_label: str + vix: float | None = None + fed_stance: str + market_regime: str + updated_at: datetime + + +class SelectedStock(BaseModel): + symbol: str + side: OrderSide + conviction: float + reason: str + key_news: list[str] + + +class Candidate(BaseModel): + symbol: str + source: str + direction: OrderSide | None = None + score: float + reason: str diff --git a/shared/src/shared/shutdown.py b/shared/src/shared/shutdown.py new file mode 100644 index 0000000..4ed9aa7 --- /dev/null +++ b/shared/src/shared/shutdown.py @@ -0,0 +1,30 @@ +"""Graceful shutdown utilities for services.""" + +import asyncio +import logging +import signal + +logger = logging.getLogger(__name__) + + +class GracefulShutdown: + """Manages graceful shutdown via SIGTERM/SIGINT signals.""" + + def __init__(self) -> None: + self._event = asyncio.Event() + + @property + def is_shutting_down(self) -> bool: + return self._event.is_set() + + async def wait(self) -> None: + await self._event.wait() + + def trigger(self) -> None: + logger.info("shutdown_signal_received") + self._event.set() + + def install_handlers(self) -> None: + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, self.trigger) diff --git a/shared/tests/test_alpaca.py b/shared/tests/test_alpaca.py index 080b7c4..55a2b24 100644 --- a/shared/tests/test_alpaca.py +++ b/shared/tests/test_alpaca.py @@ -1,7 +1,9 @@ """Tests for Alpaca API client.""" -import pytest from unittest.mock import AsyncMock, MagicMock + +import pytest + from shared.alpaca import AlpacaClient diff --git a/shared/tests/test_broker.py b/shared/tests/test_broker.py index 9be84b0..5636611 100644 --- a/shared/tests/test_broker.py +++ b/shared/tests/test_broker.py @@ -1,10 +1,11 @@ """Tests for the Redis broker.""" -import pytest import json -import redis from unittest.mock import AsyncMock, patch +import pytest +import redis + @pytest.mark.asyncio async def test_broker_publish(): @@ -16,7 +17,7 @@ async def test_broker_publish(): from shared.broker import RedisBroker broker = RedisBroker("redis://localhost:6379") - data = {"type": "CANDLE", "symbol": "BTCUSDT"} + data = {"type": "CANDLE", "symbol": "AAPL"} await broker.publish("candles", data) mock_redis.xadd.assert_called_once() @@ -35,7 +36,7 @@ async def test_broker_subscribe_returns_messages(): mock_redis = AsyncMock() mock_from_url.return_value = mock_redis - payload_data = {"type": "CANDLE", "symbol": "ETHUSDT"} + payload_data = {"type": "CANDLE", "symbol": "MSFT"} mock_redis.xread.return_value = [ [ b"candles", @@ -53,7 +54,7 @@ async def test_broker_subscribe_returns_messages(): mock_redis.xread.assert_called_once() assert len(messages) == 1 assert messages[0]["type"] == "CANDLE" - assert messages[0]["symbol"] == "ETHUSDT" + assert messages[0]["symbol"] == "MSFT" @pytest.mark.asyncio diff --git a/shared/tests/test_config_validation.py b/shared/tests/test_config_validation.py new file mode 100644 index 0000000..9376dc6 --- /dev/null +++ b/shared/tests/test_config_validation.py @@ -0,0 +1,29 @@ +"""Tests for config validation.""" + +import pytest +from pydantic import ValidationError + +from shared.config import Settings + + +class TestConfigValidation: + def test_valid_defaults(self): + settings = Settings() + assert settings.risk_max_position_size == 0.1 + + def test_invalid_position_size(self): + with pytest.raises(ValidationError, match="risk_max_position_size"): + Settings(risk_max_position_size=-0.1) + + def test_invalid_health_port(self): + with pytest.raises(ValidationError, match="health_port"): + Settings(health_port=80) + + def test_invalid_log_level(self): + with pytest.raises(ValidationError, match="log_level"): + Settings(log_level="INVALID") + + def test_secret_fields_masked(self): + settings = Settings(alpaca_api_key="my-secret-key") + assert "my-secret-key" not in repr(settings) + assert settings.alpaca_api_key.get_secret_value() == "my-secret-key" diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index 239ee64..b44a713 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -1,10 +1,11 @@ """Tests for the SQLAlchemy async database layer.""" -import pytest +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch +import pytest + def make_candle(): from shared.models import Candle @@ -12,7 +13,7 @@ def make_candle(): return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("51000"), low=Decimal("49500"), @@ -22,7 +23,7 @@ def make_candle(): def make_signal(): - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal return Signal( id="sig-1", @@ -32,12 +33,12 @@ def make_signal(): price=Decimal("50000"), quantity=Decimal("0.1"), reason="Golden cross", - created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + created_at=datetime(2024, 1, 1, tzinfo=UTC), ) def make_order(): - from shared.models import Order, OrderSide, OrderType, OrderStatus + from shared.models import Order, OrderSide, OrderStatus, OrderType return Order( id="ord-1", @@ -48,7 +49,7 @@ def make_order(): price=Decimal("50000"), quantity=Decimal("0.1"), status=OrderStatus.PENDING, - created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + created_at=datetime(2024, 1, 1, tzinfo=UTC), ) @@ -101,6 +102,54 @@ class TestDatabaseConnect: mock_create.assert_called_once() @pytest.mark.asyncio + async def test_connect_passes_pool_params_for_postgres(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect(pool_size=5, max_overflow=3, pool_recycle=1800) + mock_create.assert_called_once_with( + "postgresql+asyncpg://host/db", + pool_pre_ping=True, + pool_size=5, + max_overflow=3, + pool_recycle=1800, + ) + + @pytest.mark.asyncio + async def test_connect_skips_pool_params_for_sqlite(self): + from shared.db import Database + + db = Database("sqlite+aiosqlite:///test.db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect() + mock_create.assert_called_once_with("sqlite+aiosqlite:///test.db") + + @pytest.mark.asyncio async def test_init_tables_is_alias_for_connect(self): from shared.db import Database @@ -211,7 +260,7 @@ class TestUpdateOrderStatus: db._session_factory = MagicMock(return_value=mock_session) - filled = datetime(2024, 1, 2, tzinfo=timezone.utc) + filled = datetime(2024, 1, 2, tzinfo=UTC) await db.update_order_status("ord-1", OrderStatus.FILLED, filled) mock_session.execute.assert_awaited_once() @@ -230,7 +279,7 @@ class TestGetCandles: mock_row._mapping = { "symbol": "AAPL", "timeframe": "1m", - "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc), + "open_time": datetime(2024, 1, 1, tzinfo=UTC), "open": Decimal("50000"), "high": Decimal("51000"), "low": Decimal("49500"), @@ -396,7 +445,7 @@ class TestGetPortfolioSnapshots: mock_row.total_value = Decimal("10000") mock_row.realized_pnl = Decimal("0") mock_row.unrealized_pnl = Decimal("500") - mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=timezone.utc) + mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=UTC) mock_result = MagicMock() mock_result.scalars.return_value.all.return_value = [mock_row] diff --git a/shared/tests/test_db_news.py b/shared/tests/test_db_news.py new file mode 100644 index 0000000..c184bed --- /dev/null +++ b/shared/tests/test_db_news.py @@ -0,0 +1,79 @@ +"""Tests for database news/sentiment methods. Uses in-memory SQLite.""" + +from datetime import UTC, date, datetime + +import pytest + +from shared.db import Database +from shared.models import NewsCategory, NewsItem +from shared.sentiment_models import MarketSentiment, SymbolScore + + +@pytest.fixture +async def db(): + database = Database("sqlite+aiosqlite://") + await database.connect() + yield database + await database.close() + + +async def test_insert_and_get_news_items(db): + item = NewsItem( + source="finnhub", + headline="AAPL earnings beat", + published_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC), + sentiment=0.8, + category=NewsCategory.EARNINGS, + symbols=["AAPL"], + ) + await db.insert_news_item(item) + items = await db.get_recent_news(hours=24) + assert len(items) == 1 + assert items[0]["headline"] == "AAPL earnings beat" + + +async def test_upsert_symbol_score(db): + score = SymbolScore( + symbol="AAPL", + news_score=0.5, + news_count=10, + social_score=0.3, + policy_score=0.0, + filing_score=0.2, + composite=0.3, + updated_at=datetime(2026, 4, 2, tzinfo=UTC), + ) + await db.upsert_symbol_score(score) + scores = await db.get_top_symbol_scores(limit=5) + assert len(scores) == 1 + assert scores[0]["symbol"] == "AAPL" + + +async def test_upsert_market_sentiment(db): + ms = MarketSentiment( + fear_greed=55, + fear_greed_label="Neutral", + vix=18.2, + fed_stance="neutral", + market_regime="neutral", + updated_at=datetime(2026, 4, 2, tzinfo=UTC), + ) + await db.upsert_market_sentiment(ms) + result = await db.get_latest_market_sentiment() + assert result is not None + assert result["fear_greed"] == 55 + + +async def test_insert_stock_selection(db): + await db.insert_stock_selection( + trade_date=date(2026, 4, 2), + symbol="NVDA", + side="BUY", + conviction=0.85, + reason="CHIPS Act", + key_news=["Trump signs CHIPS expansion"], + sentiment_snapshot={"composite": 0.8}, + ) + selections = await db.get_stock_selections(date(2026, 4, 2)) + assert len(selections) == 1 + assert selections[0]["symbol"] == "NVDA" diff --git a/shared/tests/test_events.py b/shared/tests/test_events.py index 6077d93..1ccd904 100644 --- a/shared/tests/test_events.py +++ b/shared/tests/test_events.py @@ -1,7 +1,7 @@ """Tests for shared event types.""" +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone def make_candle(): @@ -10,7 +10,7 @@ def make_candle(): return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("51000"), low=Decimal("49500"), @@ -20,7 +20,7 @@ def make_candle(): def make_signal(): - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal return Signal( strategy="test", @@ -59,7 +59,7 @@ def test_candle_event_deserialize(): def test_signal_event_serialize(): """Test SignalEvent serializes to dict correctly.""" - from shared.events import SignalEvent, EventType + from shared.events import EventType, SignalEvent signal = make_signal() event = SignalEvent(data=signal) @@ -71,7 +71,7 @@ def test_signal_event_serialize(): def test_event_from_dict_dispatch(): """Test Event.from_dict dispatches to correct class.""" - from shared.events import Event, CandleEvent, SignalEvent + from shared.events import CandleEvent, Event, SignalEvent candle = make_candle() event = CandleEvent(data=candle) diff --git a/shared/tests/test_models.py b/shared/tests/test_models.py index 04098ce..40bb791 100644 --- a/shared/tests/test_models.py +++ b/shared/tests/test_models.py @@ -1,8 +1,8 @@ """Tests for shared models and settings.""" import os +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone from unittest.mock import patch @@ -12,8 +12,11 @@ def test_settings_defaults(): with patch.dict(os.environ, {}, clear=False): settings = Settings() - assert settings.redis_url == "redis://localhost:6379" - assert settings.database_url == "postgresql://trading:trading@localhost:5432/trading" + assert settings.redis_url.get_secret_value() == "redis://localhost:6379" + assert ( + settings.database_url.get_secret_value() + == "postgresql://trading:trading@localhost:5432/trading" + ) assert settings.log_level == "INFO" assert settings.risk_max_position_size == 0.1 assert settings.risk_stop_loss_pct == 5.0 @@ -25,7 +28,7 @@ def test_candle_creation(): """Test Candle model creation.""" from shared.models import Candle - now = datetime.now(timezone.utc) + now = datetime.now(UTC) candle = Candle( symbol="AAPL", timeframe="1m", @@ -47,7 +50,7 @@ def test_candle_creation(): def test_signal_creation(): """Test Signal model creation.""" - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal signal = Signal( strategy="rsi_strategy", @@ -69,9 +72,10 @@ def test_signal_creation(): def test_order_creation(): """Test Order model creation with defaults.""" - from shared.models import Order, OrderSide, OrderType, OrderStatus import uuid + from shared.models import Order, OrderSide, OrderStatus, OrderType + signal_id = str(uuid.uuid4()) order = Order( signal_id=signal_id, @@ -90,7 +94,7 @@ def test_order_creation(): def test_signal_conviction_default(): """Test Signal defaults for conviction, stop_loss, take_profit.""" - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal signal = Signal( strategy="rsi", @@ -107,7 +111,7 @@ def test_signal_conviction_default(): def test_signal_with_stops(): """Test Signal with explicit conviction, stop_loss, take_profit.""" - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal signal = Signal( strategy="rsi", diff --git a/shared/tests/test_news_events.py b/shared/tests/test_news_events.py new file mode 100644 index 0000000..f748d8a --- /dev/null +++ b/shared/tests/test_news_events.py @@ -0,0 +1,56 @@ +"""Tests for NewsEvent.""" + +from datetime import UTC, datetime + +from shared.events import Event, EventType, NewsEvent +from shared.models import NewsCategory, NewsItem + + +def test_news_event_to_dict(): + item = NewsItem( + source="finnhub", + headline="Test", + published_at=datetime(2026, 4, 2, tzinfo=UTC), + sentiment=0.5, + category=NewsCategory.MACRO, + ) + event = NewsEvent(data=item) + d = event.to_dict() + assert d["type"] == EventType.NEWS + assert d["data"]["source"] == "finnhub" + + +def test_news_event_from_raw(): + raw = { + "type": "NEWS", + "data": { + "id": "abc", + "source": "rss", + "headline": "Test headline", + "published_at": "2026-04-02T00:00:00+00:00", + "sentiment": 0.3, + "category": "earnings", + "symbols": ["AAPL"], + "raw_data": {}, + }, + } + event = NewsEvent.from_raw(raw) + assert event.data.source == "rss" + assert event.data.symbols == ["AAPL"] + + +def test_event_dispatcher_news(): + raw = { + "type": "NEWS", + "data": { + "id": "abc", + "source": "finnhub", + "headline": "Test", + "published_at": "2026-04-02T00:00:00+00:00", + "sentiment": 0.0, + "category": "macro", + "raw_data": {}, + }, + } + event = Event.from_dict(raw) + assert isinstance(event, NewsEvent) diff --git a/shared/tests/test_notifier.py b/shared/tests/test_notifier.py index 3d29830..cc98a56 100644 --- a/shared/tests/test_notifier.py +++ b/shared/tests/test_notifier.py @@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from shared.models import Signal, Order, OrderSide, OrderType, OrderStatus, Position +from shared.models import Order, OrderSide, OrderStatus, OrderType, Position, Signal from shared.notifier import TelegramNotifier @@ -86,7 +86,7 @@ class TestTelegramNotifierFormatters: notifier = TelegramNotifier(bot_token="fake-token", chat_id="123") signal = Signal( strategy="rsi_strategy", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, price=Decimal("50000.00"), quantity=Decimal("0.01"), @@ -99,7 +99,7 @@ class TestTelegramNotifierFormatters: msg = mock_send.call_args[0][0] assert "BUY" in msg assert "rsi_strategy" in msg - assert "BTCUSDT" in msg + assert "AAPL" in msg assert "50000.00" in msg assert "0.01" in msg assert "RSI oversold" in msg @@ -109,7 +109,7 @@ class TestTelegramNotifierFormatters: notifier = TelegramNotifier(bot_token="fake-token", chat_id="123") order = Order( signal_id=str(uuid.uuid4()), - symbol="ETHUSDT", + symbol="MSFT", side=OrderSide.SELL, type=OrderType.LIMIT, price=Decimal("3000.50"), @@ -122,7 +122,7 @@ class TestTelegramNotifierFormatters: mock_send.assert_called_once() msg = mock_send.call_args[0][0] assert "FILLED" in msg - assert "ETHUSDT" in msg + assert "MSFT" in msg assert "SELL" in msg assert "3000.50" in msg assert "1.5" in msg @@ -143,7 +143,7 @@ class TestTelegramNotifierFormatters: notifier = TelegramNotifier(bot_token="fake-token", chat_id="123") positions = [ Position( - symbol="BTCUSDT", + symbol="AAPL", quantity=Decimal("0.1"), avg_entry_price=Decimal("50000"), current_price=Decimal("51000"), @@ -158,7 +158,7 @@ class TestTelegramNotifierFormatters: ) mock_send.assert_called_once() msg = mock_send.call_args[0][0] - assert "BTCUSDT" in msg + assert "AAPL" in msg assert "5100.00" in msg assert "100.00" in msg diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py index e287777..e0781af 100644 --- a/shared/tests/test_resilience.py +++ b/shared/tests/test_resilience.py @@ -1,139 +1,176 @@ -"""Tests for retry with backoff and circuit breaker.""" +"""Tests for shared.resilience module.""" -import time +import asyncio import pytest -from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff +from shared.resilience import CircuitBreaker, async_timeout, retry_async +# --- retry_async tests --- -# --------------------------------------------------------------------------- -# retry_with_backoff tests -# --------------------------------------------------------------------------- - -@pytest.mark.asyncio -async def test_retry_succeeds_first_try(): +async def test_succeeds_without_retry(): + """Function succeeds first try, called once.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def succeed(): + @retry_async() + async def fn(): nonlocal call_count call_count += 1 return "ok" - result = await succeed() + result = await fn() assert result == "ok" assert call_count == 1 -@pytest.mark.asyncio -async def test_retry_succeeds_after_failures(): +async def test_retries_on_failure_then_succeeds(): + """Fails twice then succeeds, verify call count.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def flaky(): + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): nonlocal call_count call_count += 1 if call_count < 3: - raise ValueError("not yet") + raise RuntimeError("transient") return "recovered" - result = await flaky() + result = await fn() assert result == "recovered" assert call_count == 3 -@pytest.mark.asyncio -async def test_retry_raises_after_max_retries(): +async def test_raises_after_max_retries(): + """Always fails, raises after max retries.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def always_fail(): + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): nonlocal call_count call_count += 1 - raise RuntimeError("permanent") + raise ValueError("permanent") - with pytest.raises(RuntimeError, match="permanent"): - await always_fail() - # 1 initial + 3 retries = 4 calls + with pytest.raises(ValueError, match="permanent"): + await fn() + + # 1 initial + 3 retries = 4 total calls assert call_count == 4 -@pytest.mark.asyncio -async def test_retry_respects_max_delay(): - """Backoff should be capped at max_delay.""" +async def test_no_retry_on_excluded_exception(): + """Excluded exception raises immediately, call count = 1.""" + call_count = 0 - @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02) - async def always_fail(): - raise RuntimeError("fail") + @retry_async(max_retries=3, base_delay=0.01, exclude=(TypeError,)) + async def fn(): + nonlocal call_count + call_count += 1 + raise TypeError("excluded") - start = time.monotonic() - with pytest.raises(RuntimeError): - await always_fail() - elapsed = time.monotonic() - start - # With max_delay=0.02 and 2 retries, total delay should be small - assert elapsed < 0.5 + with pytest.raises(TypeError, match="excluded"): + await fn() + + assert call_count == 1 -# --------------------------------------------------------------------------- -# CircuitBreaker tests -# --------------------------------------------------------------------------- +# --- CircuitBreaker tests --- -def test_circuit_starts_closed(): - cb = CircuitBreaker(failure_threshold=3, recovery_timeout=0.05) - assert cb.state == CircuitState.CLOSED - assert cb.allow_request() is True +async def test_closed_allows_calls(): + """CircuitBreaker in closed state passes through.""" + cb = CircuitBreaker(failure_threshold=5, cooldown=60.0) + async def fn(): + return "ok" + + result = await cb.call(fn) + assert result == "ok" + + +async def test_opens_after_threshold(): + """After N failures, raises RuntimeError.""" + cb = CircuitBreaker(failure_threshold=3, cooldown=60.0) + + async def fail(): + raise RuntimeError("fail") -def test_circuit_opens_after_threshold(): - cb = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0) for _ in range(3): - cb.record_failure() - assert cb.state == CircuitState.OPEN - assert cb.allow_request() is False + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + # Now the breaker should be open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + +async def test_half_open_after_cooldown(): + """After cooldown, allows recovery attempt.""" + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def fail(): + raise RuntimeError("fail") + + # Trip the breaker + for _ in range(2): + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + + # Breaker is open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + # Wait for cooldown + await asyncio.sleep(0.06) + + # Now should allow a call (half_open). Succeed to close it. + async def succeed(): + return "recovered" + + result = await cb.call(succeed) + assert result == "recovered" + + # Breaker should be closed again + result = await cb.call(succeed) + assert result == "recovered" + + +async def test_half_open_reopens_on_failure(): + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def always_fail(): + raise ConnectionError("fail") -def test_circuit_rejects_when_open(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN - assert cb.allow_request() is False + # Trip the breaker + for _ in range(2): + with pytest.raises(ConnectionError): + await cb.call(always_fail) + # Wait for cooldown + await asyncio.sleep(0.1) -def test_circuit_half_open_after_timeout(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN + # Half-open probe should fail and re-open + with pytest.raises(ConnectionError): + await cb.call(always_fail) - time.sleep(0.06) - assert cb.allow_request() is True - assert cb.state == CircuitState.HALF_OPEN + # Should be open again (no cooldown wait) + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(always_fail) -def test_circuit_closes_on_success(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN +# --- async_timeout tests --- - time.sleep(0.06) - cb.allow_request() # triggers HALF_OPEN - assert cb.state == CircuitState.HALF_OPEN - cb.record_success() - assert cb.state == CircuitState.CLOSED - assert cb.allow_request() is True +async def test_completes_within_timeout(): + """async_timeout doesn't interfere with fast operations.""" + async with async_timeout(1.0): + await asyncio.sleep(0.01) + result = 42 + assert result == 42 -def test_circuit_reopens_on_failure_in_half_open(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - time.sleep(0.06) - cb.allow_request() # HALF_OPEN - cb.record_failure() - assert cb.state == CircuitState.OPEN +async def test_raises_on_timeout(): + """async_timeout raises TimeoutError for slow operations.""" + with pytest.raises(TimeoutError): + async with async_timeout(0.05): + await asyncio.sleep(1.0) diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py index 67c3c82..c9311dd 100644 --- a/shared/tests/test_sa_models.py +++ b/shared/tests/test_sa_models.py @@ -14,6 +14,10 @@ def test_base_metadata_has_all_tables(): "trades", "positions", "portfolio_snapshots", + "news_items", + "symbol_scores", + "market_sentiment", + "stock_selections", } assert expected == table_names @@ -68,6 +72,9 @@ class TestSignalRow: "price", "quantity", "reason", + "conviction", + "stop_loss", + "take_profit", "created_at", } assert expected == cols @@ -120,44 +127,6 @@ class TestOrderRow: assert fk_cols == {"signal_id": "signals.id"} -class TestTradeRow: - def test_table_name(self): - from shared.sa_models import TradeRow - - assert TradeRow.__tablename__ == "trades" - - def test_columns(self): - from shared.sa_models import TradeRow - - mapper = inspect(TradeRow) - cols = {c.key for c in mapper.column_attrs} - expected = { - "id", - "order_id", - "symbol", - "side", - "price", - "quantity", - "fee", - "traded_at", - } - assert expected == cols - - def test_primary_key(self): - from shared.sa_models import TradeRow - - mapper = inspect(TradeRow) - pk_cols = [c.name for c in mapper.mapper.primary_key] - assert pk_cols == ["id"] - - def test_order_id_foreign_key(self): - from shared.sa_models import TradeRow - - table = TradeRow.__table__ - fk_cols = {fk.parent.name: fk.target_fullname for fk in table.foreign_keys} - assert fk_cols == {"order_id": "orders.id"} - - class TestPositionRow: def test_table_name(self): from shared.sa_models import PositionRow @@ -229,11 +198,3 @@ class TestStatusDefault: status_col = table.c.status assert status_col.server_default is not None assert status_col.server_default.arg == "PENDING" - - def test_trade_fee_server_default(self): - from shared.sa_models import TradeRow - - table = TradeRow.__table__ - fee_col = table.c.fee - assert fee_col.server_default is not None - assert fee_col.server_default.arg == "0" diff --git a/shared/tests/test_sa_news_models.py b/shared/tests/test_sa_news_models.py new file mode 100644 index 0000000..dc2d026 --- /dev/null +++ b/shared/tests/test_sa_news_models.py @@ -0,0 +1,29 @@ +"""Tests for news-related SQLAlchemy models.""" + +from shared.sa_models import MarketSentimentRow, NewsItemRow, StockSelectionRow, SymbolScoreRow + + +def test_news_item_row_tablename(): + assert NewsItemRow.__tablename__ == "news_items" + + +def test_symbol_score_row_tablename(): + assert SymbolScoreRow.__tablename__ == "symbol_scores" + + +def test_market_sentiment_row_tablename(): + assert MarketSentimentRow.__tablename__ == "market_sentiment" + + +def test_stock_selection_row_tablename(): + assert StockSelectionRow.__tablename__ == "stock_selections" + + +def test_news_item_row_columns(): + cols = {c.name for c in NewsItemRow.__table__.columns} + assert cols >= {"id", "source", "headline", "published_at", "sentiment", "category"} + + +def test_symbol_score_row_columns(): + cols = {c.name for c in SymbolScoreRow.__table__.columns} + assert cols >= {"id", "symbol", "news_score", "composite", "updated_at"} diff --git a/shared/tests/test_sentiment.py b/shared/tests/test_sentiment.py deleted file mode 100644 index 9bd8ea3..0000000 --- a/shared/tests/test_sentiment.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Tests for market sentiment module.""" - -from shared.sentiment import SentimentData - - -def test_sentiment_should_buy_default_no_data(): - s = SentimentData() - assert s.should_buy is True - assert s.should_block is False - - -def test_sentiment_should_buy_low_fear_greed(): - s = SentimentData(fear_greed_value=15) - assert s.should_buy is True - - -def test_sentiment_should_not_buy_on_greed(): - s = SentimentData(fear_greed_value=75) - assert s.should_buy is False - - -def test_sentiment_should_not_buy_negative_news(): - s = SentimentData(news_sentiment=-0.4) - assert s.should_buy is False - - -def test_sentiment_should_buy_positive_news(): - s = SentimentData(fear_greed_value=50, news_sentiment=0.3) - assert s.should_buy is True - - -def test_sentiment_should_block_extreme_greed(): - s = SentimentData(fear_greed_value=85) - assert s.should_block is True - - -def test_sentiment_should_block_very_negative_news(): - s = SentimentData(news_sentiment=-0.6) - assert s.should_block is True - - -def test_sentiment_no_block_on_neutral(): - s = SentimentData(fear_greed_value=50, news_sentiment=0.0) - assert s.should_block is False diff --git a/shared/tests/test_sentiment_aggregator.py b/shared/tests/test_sentiment_aggregator.py new file mode 100644 index 0000000..9193785 --- /dev/null +++ b/shared/tests/test_sentiment_aggregator.py @@ -0,0 +1,79 @@ +"""Tests for sentiment aggregator.""" + +from datetime import UTC, datetime, timedelta + +import pytest + +from shared.sentiment import SentimentAggregator + + +@pytest.fixture +def aggregator(): + return SentimentAggregator() + + +def test_freshness_decay_recent(): + a = SentimentAggregator() + now = datetime.now(UTC) + assert a._freshness_decay(now, now) == 1.0 + + +def test_freshness_decay_3_hours(): + a = SentimentAggregator() + now = datetime.now(UTC) + assert a._freshness_decay(now - timedelta(hours=3), now) == 0.7 + + +def test_freshness_decay_12_hours(): + a = SentimentAggregator() + now = datetime.now(UTC) + assert a._freshness_decay(now - timedelta(hours=12), now) == 0.3 + + +def test_freshness_decay_old(): + a = SentimentAggregator() + now = datetime.now(UTC) + assert a._freshness_decay(now - timedelta(days=2), now) == 0.0 + + +def test_compute_composite(): + a = SentimentAggregator() + composite = a._compute_composite( + news_score=0.5, social_score=0.3, policy_score=0.8, filing_score=0.2 + ) + expected = 0.5 * 0.3 + 0.3 * 0.2 + 0.8 * 0.3 + 0.2 * 0.2 + assert abs(composite - expected) < 0.001 + + +def test_aggregate_news_by_symbol(aggregator): + now = datetime.now(UTC) + news_items = [ + {"symbols": ["AAPL"], "sentiment": 0.8, "category": "earnings", "published_at": now}, + { + "symbols": ["AAPL"], + "sentiment": 0.3, + "category": "macro", + "published_at": now - timedelta(hours=2), + }, + {"symbols": ["MSFT"], "sentiment": -0.5, "category": "policy", "published_at": now}, + ] + scores = aggregator.aggregate(news_items, now) + assert "AAPL" in scores + assert "MSFT" in scores + assert scores["AAPL"].news_count == 2 + assert scores["AAPL"].news_score > 0 + assert scores["MSFT"].policy_score < 0 + + +def test_aggregate_empty(aggregator): + now = datetime.now(UTC) + assert aggregator.aggregate([], now) == {} + + +def test_determine_regime(): + a = SentimentAggregator() + assert a.determine_regime(15, None) == "risk_off" + assert a.determine_regime(15, 35.0) == "risk_off" + assert a.determine_regime(50, 35.0) == "risk_off" + assert a.determine_regime(70, 15.0) == "risk_on" + assert a.determine_regime(50, 20.0) == "neutral" diff --git a/shared/tests/test_sentiment_models.py b/shared/tests/test_sentiment_models.py new file mode 100644 index 0000000..e00ffa6 --- /dev/null +++ b/shared/tests/test_sentiment_models.py @@ -0,0 +1,113 @@ +"""Tests for news and sentiment models.""" + +from datetime import UTC, datetime + +from shared.models import NewsCategory, NewsItem, OrderSide +from shared.sentiment_models import Candidate, MarketSentiment, SelectedStock, SymbolScore + + +def test_news_item_defaults(): + item = NewsItem( + source="finnhub", + headline="Test headline", + published_at=datetime(2026, 4, 2, tzinfo=UTC), + sentiment=0.5, + category=NewsCategory.MACRO, + ) + assert item.id + assert item.symbols == [] + assert item.summary is None + assert item.raw_data == {} + assert item.created_at is not None + + +def test_news_item_with_symbols(): + item = NewsItem( + source="rss", + headline="AAPL earnings beat", + published_at=datetime(2026, 4, 2, tzinfo=UTC), + sentiment=0.8, + category=NewsCategory.EARNINGS, + symbols=["AAPL"], + ) + assert item.symbols == ["AAPL"] + assert item.category == NewsCategory.EARNINGS + + +def test_news_category_values(): + assert NewsCategory.POLICY == "policy" + assert NewsCategory.EARNINGS == "earnings" + assert NewsCategory.MACRO == "macro" + assert NewsCategory.SOCIAL == "social" + assert NewsCategory.FILING == "filing" + assert NewsCategory.FED == "fed" + + +def test_symbol_score(): + score = SymbolScore( + symbol="AAPL", + news_score=0.5, + news_count=10, + social_score=0.3, + policy_score=0.0, + filing_score=0.2, + composite=0.3, + updated_at=datetime(2026, 4, 2, tzinfo=UTC), + ) + assert score.symbol == "AAPL" + assert score.composite == 0.3 + + +def test_market_sentiment(): + ms = MarketSentiment( + fear_greed=25, + fear_greed_label="Extreme Fear", + vix=32.5, + fed_stance="hawkish", + market_regime="risk_off", + updated_at=datetime(2026, 4, 2, tzinfo=UTC), + ) + assert ms.market_regime == "risk_off" + assert ms.vix == 32.5 + + +def test_market_sentiment_no_vix(): + ms = MarketSentiment( + fear_greed=50, + fear_greed_label="Neutral", + fed_stance="neutral", + market_regime="neutral", + updated_at=datetime(2026, 4, 2, tzinfo=UTC), + ) + assert ms.vix is None + + +def test_selected_stock(): + ss = SelectedStock( + symbol="NVDA", + side=OrderSide.BUY, + conviction=0.85, + reason="CHIPS Act expansion", + key_news=["Trump signs CHIPS Act expansion"], + ) + assert ss.conviction == 0.85 + assert len(ss.key_news) == 1 + + +def test_candidate(): + c = Candidate( + symbol="TSLA", + source="sentiment", + direction=OrderSide.BUY, + score=0.75, + reason="High social buzz", + ) + assert c.direction == OrderSide.BUY + + c2 = Candidate( + symbol="XOM", + source="llm", + score=0.6, + reason="Oil price surge", + ) + assert c2.direction is None |
