summaryrefslogtreecommitdiff
path: root/shared
diff options
context:
space:
mode:
Diffstat (limited to 'shared')
-rw-r--r--shared/alembic/versions/001_initial_schema.py10
-rw-r--r--shared/alembic/versions/002_news_sentiment_tables.py84
-rw-r--r--shared/alembic/versions/003_add_missing_indexes.py35
-rw-r--r--shared/alembic/versions/004_add_signal_detail_columns.py25
-rw-r--r--shared/pyproject.toml32
-rw-r--r--shared/src/shared/broker.py12
-rw-r--r--shared/src/shared/config.py51
-rw-r--r--shared/src/shared/db.py276
-rw-r--r--shared/src/shared/events.py37
-rw-r--r--shared/src/shared/healthcheck.py5
-rw-r--r--shared/src/shared/metrics.py2
-rw-r--r--shared/src/shared/models.py44
-rw-r--r--shared/src/shared/notifier.py49
-rw-r--r--shared/src/shared/resilience.py146
-rw-r--r--shared/src/shared/sa_models.py64
-rw-r--r--shared/src/shared/sentiment.py141
-rw-r--r--shared/src/shared/sentiment_models.py43
-rw-r--r--shared/src/shared/shutdown.py30
-rw-r--r--shared/tests/test_alpaca.py4
-rw-r--r--shared/tests/test_broker.py5
-rw-r--r--shared/tests/test_config_validation.py29
-rw-r--r--shared/tests/test_db.py69
-rw-r--r--shared/tests/test_db_news.py79
-rw-r--r--shared/tests/test_events.py10
-rw-r--r--shared/tests/test_models.py20
-rw-r--r--shared/tests/test_news_events.py56
-rw-r--r--shared/tests/test_notifier.py2
-rw-r--r--shared/tests/test_resilience.py203
-rw-r--r--shared/tests/test_sa_models.py53
-rw-r--r--shared/tests/test_sa_news_models.py29
-rw-r--r--shared/tests/test_sentiment.py44
-rw-r--r--shared/tests/test_sentiment_aggregator.py79
-rw-r--r--shared/tests/test_sentiment_models.py113
33 files changed, 1517 insertions, 364 deletions
diff --git a/shared/alembic/versions/001_initial_schema.py b/shared/alembic/versions/001_initial_schema.py
index 2bdaafc..7b744ee 100644
--- a/shared/alembic/versions/001_initial_schema.py
+++ b/shared/alembic/versions/001_initial_schema.py
@@ -5,16 +5,16 @@ Revises:
Create Date: 2026-04-01
"""
-from typing import Sequence, Union
+from collections.abc import Sequence
-from alembic import op
import sqlalchemy as sa
+from alembic import op
# revision identifiers, used by Alembic.
revision: str = "001"
-down_revision: Union[str, None] = None
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+down_revision: str | None = None
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
diff --git a/shared/alembic/versions/002_news_sentiment_tables.py b/shared/alembic/versions/002_news_sentiment_tables.py
new file mode 100644
index 0000000..d85a634
--- /dev/null
+++ b/shared/alembic/versions/002_news_sentiment_tables.py
@@ -0,0 +1,84 @@
+"""Add news, sentiment, and stock selection tables
+
+Revision ID: 002
+Revises: 001
+Create Date: 2026-04-02
+"""
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+from alembic import op
+
+revision: str = "002"
+down_revision: str | None = "001"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+ op.create_table(
+ "news_items",
+ sa.Column("id", sa.Text, primary_key=True),
+ sa.Column("source", sa.Text, nullable=False),
+ sa.Column("headline", sa.Text, nullable=False),
+ sa.Column("summary", sa.Text),
+ sa.Column("url", sa.Text),
+ sa.Column("published_at", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("symbols", sa.Text),
+ sa.Column("sentiment", sa.Float, nullable=False),
+ sa.Column("category", sa.Text, nullable=False),
+ sa.Column("raw_data", sa.Text),
+ sa.Column(
+ "created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()
+ ),
+ )
+ op.create_index("idx_news_items_published", "news_items", ["published_at"])
+ op.create_index("idx_news_items_source", "news_items", ["source"])
+
+ op.create_table(
+ "symbol_scores",
+ sa.Column("id", sa.Text, primary_key=True),
+ sa.Column("symbol", sa.Text, nullable=False, unique=True),
+ sa.Column("news_score", sa.Float, nullable=False, server_default="0"),
+ sa.Column("news_count", sa.Integer, nullable=False, server_default="0"),
+ sa.Column("social_score", sa.Float, nullable=False, server_default="0"),
+ sa.Column("policy_score", sa.Float, nullable=False, server_default="0"),
+ sa.Column("filing_score", sa.Float, nullable=False, server_default="0"),
+ sa.Column("composite", sa.Float, nullable=False, server_default="0"),
+ sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+ )
+
+ op.create_table(
+ "market_sentiment",
+ sa.Column("id", sa.Text, primary_key=True),
+ sa.Column("fear_greed", sa.Integer, nullable=False),
+ sa.Column("fear_greed_label", sa.Text, nullable=False),
+ sa.Column("vix", sa.Float),
+ sa.Column("fed_stance", sa.Text, nullable=False, server_default="neutral"),
+ sa.Column("market_regime", sa.Text, nullable=False, server_default="neutral"),
+ sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+ )
+
+ op.create_table(
+ "stock_selections",
+ sa.Column("id", sa.Text, primary_key=True),
+ sa.Column("trade_date", sa.Date, nullable=False),
+ sa.Column("symbol", sa.Text, nullable=False),
+ sa.Column("side", sa.Text, nullable=False),
+ sa.Column("conviction", sa.Float, nullable=False),
+ sa.Column("reason", sa.Text, nullable=False),
+ sa.Column("key_news", sa.Text),
+ sa.Column("sentiment_snapshot", sa.Text),
+ sa.Column(
+ "created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()
+ ),
+ )
+ op.create_index("idx_stock_selections_date", "stock_selections", ["trade_date"])
+
+
+def downgrade() -> None:
+ op.drop_table("stock_selections")
+ op.drop_table("market_sentiment")
+ op.drop_table("symbol_scores")
+ op.drop_table("news_items")
diff --git a/shared/alembic/versions/003_add_missing_indexes.py b/shared/alembic/versions/003_add_missing_indexes.py
new file mode 100644
index 0000000..7a252d4
--- /dev/null
+++ b/shared/alembic/versions/003_add_missing_indexes.py
@@ -0,0 +1,35 @@
+"""Add missing indexes for common query patterns.
+
+Revision ID: 003
+Revises: 002
+Create Date: 2026-04-02
+"""
+
+from collections.abc import Sequence
+
+from alembic import op
+
+revision: str = "003"
+down_revision: str | None = "002"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+ op.create_index("idx_signals_symbol_created", "signals", ["symbol", "created_at"])
+ op.create_index(
+ "idx_orders_symbol_status_created", "orders", ["symbol", "status", "created_at"]
+ )
+ op.create_index("idx_trades_order_id", "trades", ["order_id"])
+ op.create_index("idx_trades_symbol_traded", "trades", ["symbol", "traded_at"])
+ op.create_index("idx_portfolio_snapshots_at", "portfolio_snapshots", ["snapshot_at"])
+ op.create_index("idx_symbol_scores_symbol", "symbol_scores", ["symbol"], unique=True)
+
+
+def downgrade() -> None:
+ op.drop_index("idx_symbol_scores_symbol", table_name="symbol_scores")
+ op.drop_index("idx_portfolio_snapshots_at", table_name="portfolio_snapshots")
+ op.drop_index("idx_trades_symbol_traded", table_name="trades")
+ op.drop_index("idx_trades_order_id", table_name="trades")
+ op.drop_index("idx_orders_symbol_status_created", table_name="orders")
+ op.drop_index("idx_signals_symbol_created", table_name="signals")
diff --git a/shared/alembic/versions/004_add_signal_detail_columns.py b/shared/alembic/versions/004_add_signal_detail_columns.py
new file mode 100644
index 0000000..4009b6e
--- /dev/null
+++ b/shared/alembic/versions/004_add_signal_detail_columns.py
@@ -0,0 +1,25 @@
+"""Add conviction, stop_loss, take_profit columns to signals table.
+
+Revision ID: 004
+Revises: 003
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+revision = "004"
+down_revision = "003"
+
+
+def upgrade():
+ op.add_column(
+ "signals", sa.Column("conviction", sa.Float, nullable=False, server_default="1.0")
+ )
+ op.add_column("signals", sa.Column("stop_loss", sa.Numeric, nullable=True))
+ op.add_column("signals", sa.Column("take_profit", sa.Numeric, nullable=True))
+
+
+def downgrade():
+ op.drop_column("signals", "take_profit")
+ op.drop_column("signals", "stop_loss")
+ op.drop_column("signals", "conviction")
diff --git a/shared/pyproject.toml b/shared/pyproject.toml
index 830088d..eb74a11 100644
--- a/shared/pyproject.toml
+++ b/shared/pyproject.toml
@@ -4,28 +4,22 @@ version = "0.1.0"
description = "Shared models, events, and utilities for trading platform"
requires-python = ">=3.12"
dependencies = [
- "pydantic>=2.0",
- "pydantic-settings>=2.0",
- "redis>=5.0",
- "asyncpg>=0.29",
- "sqlalchemy[asyncio]>=2.0",
- "alembic>=1.13",
- "structlog>=24.0",
- "prometheus-client>=0.20",
- "pyyaml>=6.0",
- "aiohttp>=3.9",
- "rich>=13.0",
+ "pydantic>=2.8,<3",
+ "pydantic-settings>=2.0,<3",
+ "redis>=5.0,<6",
+ "asyncpg>=0.29,<1",
+ "sqlalchemy[asyncio]>=2.0,<3",
+ "alembic>=1.13,<2",
+ "structlog>=24.0,<25",
+ "prometheus-client>=0.20,<1",
+ "pyyaml>=6.0,<7",
+ "aiohttp>=3.9,<4",
+ "rich>=13.0,<14",
]
[project.optional-dependencies]
-dev = [
- "pytest>=8.0",
- "pytest-asyncio>=0.23",
- "ruff>=0.4",
-]
-claude = [
- "anthropic>=0.40",
-]
+dev = ["pytest>=8.0,<9", "pytest-asyncio>=0.23,<1", "ruff>=0.4,<1"]
+claude = ["anthropic>=0.40,<1"]
[build-system]
requires = ["hatchling"]
diff --git a/shared/src/shared/broker.py b/shared/src/shared/broker.py
index fbe4576..2b96714 100644
--- a/shared/src/shared/broker.py
+++ b/shared/src/shared/broker.py
@@ -5,13 +5,21 @@ from typing import Any
import redis.asyncio
+from shared.resilience import retry_async
+
class RedisBroker:
"""Async Redis Streams broker for publishing and reading events."""
def __init__(self, redis_url: str) -> None:
- self._redis = redis.asyncio.from_url(redis_url)
+ self._redis = redis.asyncio.from_url(
+ redis_url,
+ socket_keepalive=True,
+ health_check_interval=30,
+ retry_on_timeout=True,
+ )
+ @retry_async(max_retries=3, base_delay=0.5, exclude=(ValueError,))
async def publish(self, stream: str, data: dict[str, Any]) -> None:
"""Publish a message to a Redis stream."""
payload = json.dumps(data)
@@ -25,6 +33,7 @@ class RedisBroker:
if "BUSYGROUP" not in str(e):
raise
+ @retry_async(max_retries=3, base_delay=0.5, exclude=(ValueError,))
async def read_group(
self,
stream: str,
@@ -99,6 +108,7 @@ class RedisBroker:
messages.append(json.loads(payload))
return messages
+ @retry_async(max_retries=2, base_delay=0.5)
async def ping(self) -> bool:
"""Ping the Redis server; return True if reachable."""
return await self._redis.ping()
diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py
index 4e8e7f1..0f1c66e 100644
--- a/shared/src/shared/config.py
+++ b/shared/src/shared/config.py
@@ -1,14 +1,18 @@
"""Shared configuration settings for the trading platform."""
+from pydantic import SecretStr, field_validator
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
- alpaca_api_key: str = ""
- alpaca_api_secret: str = ""
+ alpaca_api_key: SecretStr = SecretStr("")
+ alpaca_api_secret: SecretStr = SecretStr("")
alpaca_paper: bool = True # Use paper trading by default
- redis_url: str = "redis://localhost:6379"
- database_url: str = "postgresql://trading:trading@localhost:5432/trading"
+ redis_url: SecretStr = SecretStr("redis://localhost:6379")
+ database_url: SecretStr = SecretStr("postgresql://trading:trading@localhost:5432/trading")
+ db_pool_size: int = 20
+ db_max_overflow: int = 10
+ db_pool_recycle: int = 3600
log_level: str = "INFO"
risk_max_position_size: float = 0.1
risk_stop_loss_pct: float = 5.0
@@ -27,12 +31,45 @@ class Settings(BaseSettings):
risk_max_consecutive_losses: int = 5
risk_loss_pause_minutes: int = 60
dry_run: bool = True
- telegram_bot_token: str = ""
+ telegram_bot_token: SecretStr = SecretStr("")
telegram_chat_id: str = ""
telegram_enabled: bool = False
log_format: str = "json"
health_port: int = 8080
- circuit_breaker_threshold: int = 5
- circuit_breaker_timeout: int = 60
metrics_auth_token: str = "" # If set, /health and /metrics require Bearer token
+ # API security
+ api_auth_token: SecretStr = SecretStr("")
+ cors_origins: str = "http://localhost:3000"
+ # News collector
+ finnhub_api_key: SecretStr = SecretStr("")
+ news_poll_interval: int = 300
+ sentiment_aggregate_interval: int = 900
+ # Stock selector
+ selector_final_time: str = "15:30"
+ selector_max_picks: int = 3
+ # LLM
+ anthropic_api_key: SecretStr = SecretStr("")
+ anthropic_model: str = "claude-sonnet-4-20250514"
model_config = {"env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"}
+
+ @field_validator("risk_max_position_size")
+ @classmethod
+ def validate_position_size(cls, v: float) -> float:
+ if v <= 0 or v > 1:
+ raise ValueError("risk_max_position_size must be between 0 and 1 (exclusive)")
+ return v
+
+ @field_validator("health_port")
+ @classmethod
+ def validate_health_port(cls, v: int) -> int:
+ if v < 1024 or v > 65535:
+ raise ValueError("health_port must be between 1024 and 65535")
+ return v
+
+ @field_validator("log_level")
+ @classmethod
+ def validate_log_level(cls, v: str) -> str:
+ valid = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
+ if v.upper() not in valid:
+ raise ValueError(f"log_level must be one of {valid}")
+ return v.upper()
diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py
index 901e293..8fee000 100644
--- a/shared/src/shared/db.py
+++ b/shared/src/shared/db.py
@@ -1,15 +1,27 @@
"""Database layer using SQLAlchemy 2.0 async ORM for the trading platform."""
+import json
+import uuid
from contextlib import asynccontextmanager
-from datetime import datetime, timedelta, timezone
+from datetime import UTC, date, datetime, timedelta
from decimal import Decimal
-from typing import Optional
from sqlalchemy import select, update
-from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
-from shared.models import Candle, Signal, Order, OrderStatus
-from shared.sa_models import Base, CandleRow, SignalRow, OrderRow, PortfolioSnapshotRow
+from shared.models import Candle, NewsItem, Order, OrderStatus, Signal
+from shared.sa_models import (
+ Base,
+ CandleRow,
+ MarketSentimentRow,
+ NewsItemRow,
+ OrderRow,
+ PortfolioSnapshotRow,
+ SignalRow,
+ StockSelectionRow,
+ SymbolScoreRow,
+)
+from shared.sentiment_models import MarketSentiment, SymbolScore
class Database:
@@ -23,9 +35,24 @@ class Database:
self._engine = None
self._session_factory = None
- async def connect(self) -> None:
+ async def connect(
+ self,
+ pool_size: int = 20,
+ max_overflow: int = 10,
+ pool_recycle: int = 3600,
+ ) -> None:
"""Create the async engine, session factory, and all tables."""
- self._engine = create_async_engine(self._database_url)
+ if self._database_url.startswith("sqlite"):
+ # SQLite doesn't support pooling options
+ self._engine = create_async_engine(self._database_url)
+ else:
+ self._engine = create_async_engine(
+ self._database_url,
+ pool_pre_ping=True,
+ pool_size=pool_size,
+ max_overflow=max_overflow,
+ pool_recycle=pool_recycle,
+ )
self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
@@ -85,6 +112,9 @@ class Database:
price=signal.price,
quantity=signal.quantity,
reason=signal.reason,
+ conviction=signal.conviction,
+ stop_loss=signal.stop_loss,
+ take_profit=signal.take_profit,
created_at=signal.created_at,
)
async with self._session_factory() as session:
@@ -121,7 +151,7 @@ class Database:
self,
order_id: str,
status: OrderStatus,
- filled_at: Optional[datetime] = None,
+ filled_at: datetime | None = None,
) -> None:
"""Update the status (and optionally filled_at) of an order."""
stmt = (
@@ -167,7 +197,7 @@ class Database:
total_value=total_value,
realized_pnl=realized_pnl,
unrealized_pnl=unrealized_pnl,
- snapshot_at=datetime.now(timezone.utc),
+ snapshot_at=datetime.now(UTC),
)
session.add(row)
await session.commit()
@@ -178,7 +208,7 @@ class Database:
async def get_portfolio_snapshots(self, days: int = 30) -> list[dict]:
"""Retrieve recent portfolio snapshots."""
async with self.get_session() as session:
- since = datetime.now(timezone.utc) - timedelta(days=days)
+ since = datetime.now(UTC) - timedelta(days=days)
stmt = (
select(PortfolioSnapshotRow)
.where(PortfolioSnapshotRow.snapshot_at >= since)
@@ -195,3 +225,229 @@ class Database:
}
for r in rows
]
+
+ async def insert_news_item(self, item: NewsItem) -> None:
+ """Insert a NewsItem row, JSON-encoding symbols and raw_data."""
+ row = NewsItemRow(
+ id=item.id,
+ source=item.source,
+ headline=item.headline,
+ summary=item.summary,
+ url=item.url,
+ published_at=item.published_at,
+ symbols=json.dumps(item.symbols),
+ sentiment=item.sentiment,
+ category=item.category.value,
+ raw_data=json.dumps(item.raw_data),
+ created_at=item.created_at,
+ )
+ async with self._session_factory() as session:
+ try:
+ session.add(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
+
+ async def get_recent_news(self, hours: int = 24) -> list[dict]:
+ """Retrieve news items published in the last N hours."""
+ since = datetime.now(UTC) - timedelta(hours=hours)
+ stmt = (
+ select(NewsItemRow)
+ .where(NewsItemRow.published_at >= since)
+ .order_by(NewsItemRow.published_at.desc())
+ )
+ async with self._session_factory() as session:
+ try:
+ result = await session.execute(stmt)
+ rows = result.scalars().all()
+ except Exception:
+ await session.rollback()
+ raise
+ return [
+ {
+ "id": r.id,
+ "source": r.source,
+ "headline": r.headline,
+ "summary": r.summary,
+ "url": r.url,
+ "published_at": r.published_at,
+ "symbols": json.loads(r.symbols) if r.symbols else [],
+ "sentiment": r.sentiment,
+ "category": r.category,
+ "raw_data": json.loads(r.raw_data) if r.raw_data else {},
+ "created_at": r.created_at,
+ }
+ for r in rows
+ ]
+
+ async def upsert_symbol_score(self, score: SymbolScore) -> None:
+ """Insert or update a SymbolScore row, keyed by symbol."""
+ async with self._session_factory() as session:
+ try:
+ stmt = select(SymbolScoreRow).where(SymbolScoreRow.symbol == score.symbol)
+ result = await session.execute(stmt)
+ existing = result.scalar_one_or_none()
+ if existing is not None:
+ existing.news_score = score.news_score
+ existing.news_count = score.news_count
+ existing.social_score = score.social_score
+ existing.policy_score = score.policy_score
+ existing.filing_score = score.filing_score
+ existing.composite = score.composite
+ existing.updated_at = score.updated_at
+ else:
+ row = SymbolScoreRow(
+ id=str(uuid.uuid4()),
+ symbol=score.symbol,
+ news_score=score.news_score,
+ news_count=score.news_count,
+ social_score=score.social_score,
+ policy_score=score.policy_score,
+ filing_score=score.filing_score,
+ composite=score.composite,
+ updated_at=score.updated_at,
+ )
+ session.add(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
+
+ async def get_top_symbol_scores(self, limit: int = 20) -> list[dict]:
+ """Retrieve top symbol scores ordered by composite descending."""
+ stmt = select(SymbolScoreRow).order_by(SymbolScoreRow.composite.desc()).limit(limit)
+ async with self._session_factory() as session:
+ try:
+ result = await session.execute(stmt)
+ rows = result.scalars().all()
+ except Exception:
+ await session.rollback()
+ raise
+ return [
+ {
+ "id": r.id,
+ "symbol": r.symbol,
+ "news_score": r.news_score,
+ "news_count": r.news_count,
+ "social_score": r.social_score,
+ "policy_score": r.policy_score,
+ "filing_score": r.filing_score,
+ "composite": r.composite,
+ "updated_at": r.updated_at,
+ }
+ for r in rows
+ ]
+
+ async def upsert_market_sentiment(self, ms: MarketSentiment) -> None:
+ """Insert or update the single 'latest' market sentiment row."""
+ async with self._session_factory() as session:
+ try:
+ stmt = select(MarketSentimentRow).where(MarketSentimentRow.id == "latest")
+ result = await session.execute(stmt)
+ existing = result.scalar_one_or_none()
+ if existing is not None:
+ existing.fear_greed = ms.fear_greed
+ existing.fear_greed_label = ms.fear_greed_label
+ existing.vix = ms.vix
+ existing.fed_stance = ms.fed_stance
+ existing.market_regime = ms.market_regime
+ existing.updated_at = ms.updated_at
+ else:
+ row = MarketSentimentRow(
+ id="latest",
+ fear_greed=ms.fear_greed,
+ fear_greed_label=ms.fear_greed_label,
+ vix=ms.vix,
+ fed_stance=ms.fed_stance,
+ market_regime=ms.market_regime,
+ updated_at=ms.updated_at,
+ )
+ session.add(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
+
+ async def get_latest_market_sentiment(self) -> dict | None:
+ """Retrieve the 'latest' market sentiment row, or None if not found."""
+ stmt = select(MarketSentimentRow).where(MarketSentimentRow.id == "latest")
+ async with self._session_factory() as session:
+ try:
+ result = await session.execute(stmt)
+ row = result.scalar_one_or_none()
+ except Exception:
+ await session.rollback()
+ raise
+ if row is None:
+ return None
+ return {
+ "id": row.id,
+ "fear_greed": row.fear_greed,
+ "fear_greed_label": row.fear_greed_label,
+ "vix": row.vix,
+ "fed_stance": row.fed_stance,
+ "market_regime": row.market_regime,
+ "updated_at": row.updated_at,
+ }
+
+ async def insert_stock_selection(
+ self,
+ trade_date: date,
+ symbol: str,
+ side: str,
+ conviction: float,
+ reason: str,
+ key_news: list,
+ sentiment_snapshot: dict,
+ ) -> None:
+ """Insert a stock selection row with JSON-encoded lists/dicts."""
+ row = StockSelectionRow(
+ id=str(uuid.uuid4()),
+ trade_date=trade_date,
+ symbol=symbol,
+ side=side,
+ conviction=conviction,
+ reason=reason,
+ key_news=json.dumps(key_news),
+ sentiment_snapshot=json.dumps(sentiment_snapshot),
+ created_at=datetime.now(UTC),
+ )
+ async with self._session_factory() as session:
+ try:
+ session.add(row)
+ await session.commit()
+ except Exception:
+ await session.rollback()
+ raise
+
+ async def get_stock_selections(self, trade_date: date) -> list[dict]:
+ """Retrieve stock selections for a given trade date."""
+ stmt = (
+ select(StockSelectionRow)
+ .where(StockSelectionRow.trade_date == trade_date)
+ .order_by(StockSelectionRow.conviction.desc())
+ )
+ async with self._session_factory() as session:
+ try:
+ result = await session.execute(stmt)
+ rows = result.scalars().all()
+ except Exception:
+ await session.rollback()
+ raise
+ return [
+ {
+ "id": r.id,
+ "trade_date": r.trade_date,
+ "symbol": r.symbol,
+ "side": r.side,
+ "conviction": r.conviction,
+ "reason": r.reason,
+ "key_news": json.loads(r.key_news) if r.key_news else [],
+ "sentiment_snapshot": json.loads(r.sentiment_snapshot)
+ if r.sentiment_snapshot
+ else {},
+ "created_at": r.created_at,
+ }
+ for r in rows
+ ]
diff --git a/shared/src/shared/events.py b/shared/src/shared/events.py
index 72f8865..37217a0 100644
--- a/shared/src/shared/events.py
+++ b/shared/src/shared/events.py
@@ -1,17 +1,18 @@
"""Event types and serialization for the trading platform."""
-from enum import Enum
+from enum import StrEnum
from typing import Any
from pydantic import BaseModel
-from shared.models import Candle, Signal, Order
+from shared.models import Candle, NewsItem, Order, Signal
-class EventType(str, Enum):
+class EventType(StrEnum):
CANDLE = "CANDLE"
SIGNAL = "SIGNAL"
ORDER = "ORDER"
+ NEWS = "NEWS"
class CandleEvent(BaseModel):
@@ -59,10 +60,26 @@ class OrderEvent(BaseModel):
return cls(type=raw["type"], data=Order(**raw["data"]))
+class NewsEvent(BaseModel):
+ type: EventType = EventType.NEWS
+ data: NewsItem
+
+ def to_dict(self) -> dict:
+ return {
+ "type": self.type,
+ "data": self.data.model_dump(mode="json"),
+ }
+
+ @classmethod
+ def from_raw(cls, raw: dict) -> "NewsEvent":
+ return cls(type=raw["type"], data=NewsItem(**raw["data"]))
+
+
_EVENT_TYPE_MAP = {
EventType.CANDLE: CandleEvent,
EventType.SIGNAL: SignalEvent,
EventType.ORDER: OrderEvent,
+ EventType.NEWS: NewsEvent,
}
@@ -71,6 +88,16 @@ class Event:
@staticmethod
def from_dict(data: dict) -> Any:
- event_type = EventType(data["type"])
+ """Deserialize a raw dict into the appropriate event type.
+
+ Raises ValueError for malformed or unrecognized event data.
+ """
+ try:
+ event_type = EventType(data["type"])
+ except (KeyError, ValueError) as exc:
+ raise ValueError(f"Invalid or missing event type in data: {data!r}") from exc
cls = _EVENT_TYPE_MAP[event_type]
- return cls.from_raw(data)
+ try:
+ return cls.from_raw(data)
+ except KeyError as exc:
+ raise ValueError(f"Missing required field in {event_type} event data: {exc}") from exc
diff --git a/shared/src/shared/healthcheck.py b/shared/src/shared/healthcheck.py
index 7411e8a..a19705b 100644
--- a/shared/src/shared/healthcheck.py
+++ b/shared/src/shared/healthcheck.py
@@ -3,10 +3,11 @@
from __future__ import annotations
import time
-from typing import Any, Callable, Awaitable
+from collections.abc import Awaitable, Callable
+from typing import Any
from aiohttp import web
-from prometheus_client import CollectorRegistry, REGISTRY, generate_latest, CONTENT_TYPE_LATEST
+from prometheus_client import CONTENT_TYPE_LATEST, REGISTRY, CollectorRegistry, generate_latest
class HealthCheckServer:
diff --git a/shared/src/shared/metrics.py b/shared/src/shared/metrics.py
index cd239f3..6189143 100644
--- a/shared/src/shared/metrics.py
+++ b/shared/src/shared/metrics.py
@@ -2,7 +2,7 @@
from __future__ import annotations
-from prometheus_client import Counter, Gauge, Histogram, CollectorRegistry, REGISTRY
+from prometheus_client import REGISTRY, CollectorRegistry, Counter, Gauge, Histogram
class ServiceMetrics:
diff --git a/shared/src/shared/models.py b/shared/src/shared/models.py
index 70820b5..f357f9f 100644
--- a/shared/src/shared/models.py
+++ b/shared/src/shared/models.py
@@ -1,25 +1,24 @@
"""Shared Pydantic models for the trading platform."""
import uuid
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
-from enum import Enum
-from typing import Optional
+from enum import StrEnum
from pydantic import BaseModel, Field, computed_field
-class OrderSide(str, Enum):
+class OrderSide(StrEnum):
BUY = "BUY"
SELL = "SELL"
-class OrderType(str, Enum):
+class OrderType(StrEnum):
MARKET = "MARKET"
LIMIT = "LIMIT"
-class OrderStatus(str, Enum):
+class OrderStatus(StrEnum):
PENDING = "PENDING"
FILLED = "FILLED"
CANCELLED = "CANCELLED"
@@ -46,9 +45,9 @@ class Signal(BaseModel):
quantity: Decimal
reason: str
conviction: float = 1.0 # 0.0 to 1.0, signal strength/confidence
- stop_loss: Optional[Decimal] = None # Price to exit at loss
- take_profit: Optional[Decimal] = None # Price to exit at profit
- created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+ stop_loss: Decimal | None = None # Price to exit at loss
+ take_profit: Decimal | None = None # Price to exit at profit
+ created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
class Order(BaseModel):
@@ -60,8 +59,8 @@ class Order(BaseModel):
price: Decimal
quantity: Decimal
status: OrderStatus = OrderStatus.PENDING
- created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
- filled_at: Optional[datetime] = None
+ created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
+ filled_at: datetime | None = None
class Position(BaseModel):
@@ -74,3 +73,26 @@ class Position(BaseModel):
@property
def unrealized_pnl(self) -> Decimal:
return self.quantity * (self.current_price - self.avg_entry_price)
+
+
+class NewsCategory(StrEnum):
+ POLICY = "policy"
+ EARNINGS = "earnings"
+ MACRO = "macro"
+ SOCIAL = "social"
+ FILING = "filing"
+ FED = "fed"
+
+
+class NewsItem(BaseModel):
+ id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+ source: str
+ headline: str
+ summary: str | None = None
+ url: str | None = None
+ published_at: datetime
+ symbols: list[str] = []
+ sentiment: float
+ category: NewsCategory
+ raw_data: dict = {}
+ created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
diff --git a/shared/src/shared/notifier.py b/shared/src/shared/notifier.py
index f03919c..cfc86cd 100644
--- a/shared/src/shared/notifier.py
+++ b/shared/src/shared/notifier.py
@@ -2,12 +2,13 @@
import asyncio
import logging
+from collections.abc import Sequence
from decimal import Decimal
-from typing import Optional, Sequence
import aiohttp
-from shared.models import Signal, Order, Position
+from shared.models import Order, Position, Signal
+from shared.sentiment_models import MarketSentiment, SelectedStock
logger = logging.getLogger(__name__)
@@ -22,7 +23,7 @@ class TelegramNotifier:
self._bot_token = bot_token
self._chat_id = chat_id
self._semaphore = asyncio.Semaphore(1)
- self._session: Optional[aiohttp.ClientSession] = None
+ self._session: aiohttp.ClientSession | None = None
@property
def enabled(self) -> bool:
@@ -112,17 +113,45 @@ class TelegramNotifier:
"",
"<b>Positions:</b>",
]
- for pos in positions:
- lines.append(
- f" {pos.symbol}: qty={pos.quantity} "
- f"entry={pos.avg_entry_price} "
- f"current={pos.current_price} "
- f"pnl={pos.unrealized_pnl}"
- )
+ lines.extend(
+ f" {pos.symbol}: qty={pos.quantity} "
+ f"entry={pos.avg_entry_price} "
+ f"current={pos.current_price} "
+ f"pnl={pos.unrealized_pnl}"
+ for pos in positions
+ )
if not positions:
lines.append(" No open positions")
await self.send("\n".join(lines))
+ async def send_stock_selection(
+ self,
+ selections: list[SelectedStock],
+ market: MarketSentiment | None = None,
+ ) -> None:
+ """Format and send stock selection notification."""
+ lines = [f"<b>📊 Stock Selection ({len(selections)} picks)</b>", ""]
+
+ side_emoji = {"BUY": "🟢", "SELL": "🔴"}
+
+ for i, s in enumerate(selections, 1):
+ emoji = side_emoji.get(s.side.value, "⚪")
+ lines.append(
+ f"{i}. <b>{s.symbol}</b> {emoji} {s.side.value} (conviction: {s.conviction:.0%})"
+ )
+ lines.append(f" {s.reason}")
+ if s.key_news:
+ lines.append(f" News: {s.key_news[0]}")
+ lines.append("")
+
+ if market:
+ lines.append(
+ f"Market: F&amp;G {market.fear_greed} ({market.fear_greed_label})"
+ + (f" | VIX {market.vix:.1f}" if market.vix else "")
+ )
+
+ await self.send("\n".join(lines))
+
async def close(self) -> None:
"""Close the underlying aiohttp session."""
if self._session is not None:
diff --git a/shared/src/shared/resilience.py b/shared/src/shared/resilience.py
index e43fd21..66225d7 100644
--- a/shared/src/shared/resilience.py
+++ b/shared/src/shared/resilience.py
@@ -1,29 +1,45 @@
-"""Retry with exponential backoff and circuit breaker utilities."""
+"""Resilience utilities for the trading platform.
+
+Provides retry, circuit breaker, and timeout primitives using only stdlib.
+No external dependencies required.
+"""
from __future__ import annotations
import asyncio
-import enum
import functools
import logging
import random
import time
-from typing import Any, Callable
+from collections.abc import Callable
+from contextlib import asynccontextmanager
+from enum import StrEnum
+from typing import Any
-logger = logging.getLogger(__name__)
+
+class _State(StrEnum):
+ CLOSED = "closed"
+ OPEN = "open"
+ HALF_OPEN = "half_open"
-# ---------------------------------------------------------------------------
-# retry_with_backoff
-# ---------------------------------------------------------------------------
+logger = logging.getLogger(__name__)
-def retry_with_backoff(
+def retry_async(
max_retries: int = 3,
base_delay: float = 1.0,
- max_delay: float = 60.0,
+ max_delay: float = 30.0,
+ exclude: tuple[type[BaseException], ...] = (),
) -> Callable:
- """Decorator that retries an async function with exponential backoff + jitter."""
+ """Decorator: exponential backoff + jitter for async functions.
+
+ Parameters:
+ max_retries: Maximum number of retry attempts (after the initial call).
+ base_delay: Base delay in seconds for exponential backoff.
+ max_delay: Maximum delay cap in seconds.
+ exclude: Exception types that should NOT be retried (raised immediately).
+ """
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
@@ -33,20 +49,21 @@ def retry_with_backoff(
try:
return await func(*args, **kwargs)
except Exception as exc:
+ if exclude and isinstance(exc, exclude):
+ raise
last_exc = exc
if attempt < max_retries:
delay = min(base_delay * (2**attempt), max_delay)
- jitter = delay * random.uniform(0, 0.5)
- total_delay = delay + jitter
+ jitter_delay = delay * random.uniform(0.5, 1.0)
logger.warning(
- "Retry %d/%d for %s after error: %s (delay=%.3fs)",
+ "Retry %d/%d for %s in %.2fs: %s",
attempt + 1,
max_retries,
func.__name__,
+ jitter_delay,
exc,
- total_delay,
)
- await asyncio.sleep(total_delay)
+ await asyncio.sleep(jitter_delay)
raise last_exc # type: ignore[misc]
return wrapper
@@ -54,52 +71,65 @@ def retry_with_backoff(
return decorator
-# ---------------------------------------------------------------------------
-# CircuitBreaker
-# ---------------------------------------------------------------------------
-
-
-class CircuitState(enum.Enum):
- CLOSED = "closed"
- OPEN = "open"
- HALF_OPEN = "half_open"
+class CircuitBreaker:
+ """Circuit breaker: opens after N consecutive failures, auto-recovers.
+ States: closed -> open -> half_open -> closed
-class CircuitBreaker:
- """Simple circuit breaker implementation."""
+ Parameters:
+ failure_threshold: Number of consecutive failures before opening.
+ cooldown: Seconds to wait before allowing a half-open probe.
+ """
- def __init__(
- self,
- failure_threshold: int = 5,
- recovery_timeout: float = 60.0,
- ) -> None:
+ def __init__(self, failure_threshold: int = 5, cooldown: float = 60.0) -> None:
self._failure_threshold = failure_threshold
- self._recovery_timeout = recovery_timeout
- self._failure_count: int = 0
- self._state = CircuitState.CLOSED
+ self._cooldown = cooldown
+ self._failures = 0
+ self._state = _State.CLOSED
self._opened_at: float = 0.0
- @property
- def state(self) -> CircuitState:
- return self._state
-
- def allow_request(self) -> bool:
- if self._state == CircuitState.CLOSED:
- return True
- if self._state == CircuitState.OPEN:
- if time.monotonic() - self._opened_at >= self._recovery_timeout:
- self._state = CircuitState.HALF_OPEN
- return True
- return False
- # HALF_OPEN
- return True
-
- def record_success(self) -> None:
- self._failure_count = 0
- self._state = CircuitState.CLOSED
-
- def record_failure(self) -> None:
- self._failure_count += 1
- if self._failure_count >= self._failure_threshold:
- self._state = CircuitState.OPEN
- self._opened_at = time.monotonic()
+ async def call(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
+ """Execute func through the breaker."""
+ if self._state == _State.OPEN:
+ if time.monotonic() - self._opened_at >= self._cooldown:
+ self._state = _State.HALF_OPEN
+ else:
+ raise RuntimeError("Circuit breaker is open")
+
+ try:
+ result = await func(*args, **kwargs)
+ except Exception:
+ self._failures += 1
+ if self._state == _State.HALF_OPEN:
+ self._state = _State.OPEN
+ self._opened_at = time.monotonic()
+ logger.error(
+ "Circuit breaker re-opened after half-open probe failure (threshold=%d)",
+ self._failure_threshold,
+ )
+ elif self._failures >= self._failure_threshold:
+ self._state = _State.OPEN
+ self._opened_at = time.monotonic()
+ logger.error(
+ "Circuit breaker opened after %d consecutive failures",
+ self._failures,
+ )
+ raise
+
+ # Success: reset
+ self._failures = 0
+ self._state = _State.CLOSED
+ return result
+
+
+@asynccontextmanager
+async def async_timeout(seconds: float):
+ """Async context manager wrapping asyncio.timeout().
+
+ Raises TimeoutError with a descriptive message on timeout.
+ """
+ try:
+ async with asyncio.timeout(seconds):
+ yield
+ except TimeoutError:
+ raise TimeoutError(f"Operation timed out after {seconds}s") from None
diff --git a/shared/src/shared/sa_models.py b/shared/src/shared/sa_models.py
index 8386ba8..b70a6c4 100644
--- a/shared/src/shared/sa_models.py
+++ b/shared/src/shared/sa_models.py
@@ -3,6 +3,7 @@
from datetime import datetime
from decimal import Decimal
+import sqlalchemy as sa
from sqlalchemy import DateTime, ForeignKey, Numeric, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
@@ -34,6 +35,9 @@ class SignalRow(Base):
price: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
reason: Mapped[str | None] = mapped_column(Text)
+ conviction: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="1.0")
+ stop_loss: Mapped[Decimal | None] = mapped_column(Numeric)
+ take_profit: Mapped[Decimal | None] = mapped_column(Numeric)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
@@ -83,3 +87,63 @@ class PortfolioSnapshotRow(Base):
realized_pnl: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
unrealized_pnl: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
snapshot_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
+
+
+class NewsItemRow(Base):
+ __tablename__ = "news_items"
+
+ id: Mapped[str] = mapped_column(Text, primary_key=True)
+ source: Mapped[str] = mapped_column(Text, nullable=False)
+ headline: Mapped[str] = mapped_column(Text, nullable=False)
+ summary: Mapped[str | None] = mapped_column(Text)
+ url: Mapped[str | None] = mapped_column(Text)
+ published_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
+ symbols: Mapped[str | None] = mapped_column(Text) # JSON-encoded list
+ sentiment: Mapped[float] = mapped_column(sa.Float, nullable=False)
+ category: Mapped[str] = mapped_column(Text, nullable=False)
+ raw_data: Mapped[str | None] = mapped_column(Text) # JSON string
+ created_at: Mapped[datetime] = mapped_column(
+ DateTime(timezone=True), nullable=False, server_default=sa.func.now()
+ )
+
+
+class SymbolScoreRow(Base):
+ __tablename__ = "symbol_scores"
+
+ id: Mapped[str] = mapped_column(Text, primary_key=True)
+ symbol: Mapped[str] = mapped_column(Text, nullable=False, unique=True)
+ news_score: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="0")
+ news_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="0")
+ social_score: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="0")
+ policy_score: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="0")
+ filing_score: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="0")
+ composite: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="0")
+ updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
+
+
+class MarketSentimentRow(Base):
+ __tablename__ = "market_sentiment"
+
+ id: Mapped[str] = mapped_column(Text, primary_key=True)
+ fear_greed: Mapped[int] = mapped_column(sa.Integer, nullable=False)
+ fear_greed_label: Mapped[str] = mapped_column(Text, nullable=False)
+ vix: Mapped[float | None] = mapped_column(sa.Float)
+ fed_stance: Mapped[str] = mapped_column(Text, nullable=False, server_default="neutral")
+ market_regime: Mapped[str] = mapped_column(Text, nullable=False, server_default="neutral")
+ updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
+
+
+class StockSelectionRow(Base):
+ __tablename__ = "stock_selections"
+
+ id: Mapped[str] = mapped_column(Text, primary_key=True)
+ trade_date: Mapped[datetime] = mapped_column(sa.Date, nullable=False)
+ symbol: Mapped[str] = mapped_column(Text, nullable=False)
+ side: Mapped[str] = mapped_column(Text, nullable=False)
+ conviction: Mapped[float] = mapped_column(sa.Float, nullable=False)
+ reason: Mapped[str] = mapped_column(Text, nullable=False)
+ key_news: Mapped[str | None] = mapped_column(Text) # JSON string
+ sentiment_snapshot: Mapped[str | None] = mapped_column(Text) # JSON string
+ created_at: Mapped[datetime] = mapped_column(
+ DateTime(timezone=True), nullable=False, server_default=sa.func.now()
+ )
diff --git a/shared/src/shared/sentiment.py b/shared/src/shared/sentiment.py
index 8213b47..c56da3e 100644
--- a/shared/src/shared/sentiment.py
+++ b/shared/src/shared/sentiment.py
@@ -1,35 +1,106 @@
-"""Market sentiment data."""
-
-import logging
-from dataclasses import dataclass, field
-from datetime import datetime, timezone
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class SentimentData:
- """Aggregated sentiment snapshot."""
-
- fear_greed_value: int | None = None
- fear_greed_label: str | None = None
- news_sentiment: float | None = None
- news_count: int = 0
- exchange_netflow: float | None = None
- timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
-
- @property
- def should_buy(self) -> bool:
- if self.fear_greed_value is not None and self.fear_greed_value > 70:
- return False
- if self.news_sentiment is not None and self.news_sentiment < -0.3:
- return False
- return True
-
- @property
- def should_block(self) -> bool:
- if self.fear_greed_value is not None and self.fear_greed_value > 80:
- return True
- if self.news_sentiment is not None and self.news_sentiment < -0.5:
- return True
- return False
+"""Market sentiment aggregation."""
+
+from datetime import datetime
+from typing import ClassVar
+
+from shared.sentiment_models import SymbolScore
+
+
+def _safe_avg(values: list[float]) -> float:
+ if not values:
+ return 0.0
+ return sum(values) / len(values)
+
+
+class SentimentAggregator:
+ """Aggregates per-news sentiment into per-symbol scores."""
+
+ WEIGHTS: ClassVar[dict[str, float]] = {"news": 0.3, "social": 0.2, "policy": 0.3, "filing": 0.2}
+
+ CATEGORY_MAP: ClassVar[dict[str, str]] = {
+ "earnings": "news",
+ "macro": "news",
+ "social": "social",
+ "policy": "policy",
+ "filing": "filing",
+ "fed": "policy",
+ }
+
+ def _freshness_decay(self, published_at: datetime, now: datetime) -> float:
+ age = now - published_at
+ hours = age.total_seconds() / 3600
+ if hours < 1:
+ return 1.0
+ if hours < 6:
+ return 0.7
+ if hours < 24:
+ return 0.3
+ return 0.0
+
+ def _compute_composite(
+ self,
+ news_score: float,
+ social_score: float,
+ policy_score: float,
+ filing_score: float,
+ ) -> float:
+ return (
+ news_score * self.WEIGHTS["news"]
+ + social_score * self.WEIGHTS["social"]
+ + policy_score * self.WEIGHTS["policy"]
+ + filing_score * self.WEIGHTS["filing"]
+ )
+
+ def aggregate(self, news_items: list[dict], now: datetime) -> dict[str, SymbolScore]:
+ """Aggregate news items into per-symbol scores.
+
+ Each dict needs: symbols, sentiment, category, published_at.
+ """
+ symbol_data: dict[str, dict] = {}
+
+ for item in news_items:
+ decay = self._freshness_decay(item["published_at"], now)
+ if decay == 0.0:
+ continue
+ category = item.get("category", "macro")
+ score_field = self.CATEGORY_MAP.get(category, "news")
+ weighted_sentiment = item["sentiment"] * decay
+
+ for symbol in item.get("symbols", []):
+ if symbol not in symbol_data:
+ symbol_data[symbol] = {
+ "news_scores": [],
+ "social_scores": [],
+ "policy_scores": [],
+ "filing_scores": [],
+ "count": 0,
+ }
+ symbol_data[symbol][f"{score_field}_scores"].append(weighted_sentiment)
+ symbol_data[symbol]["count"] += 1
+
+ result = {}
+ for symbol, data in symbol_data.items():
+ ns = _safe_avg(data["news_scores"])
+ ss = _safe_avg(data["social_scores"])
+ ps = _safe_avg(data["policy_scores"])
+ fs = _safe_avg(data["filing_scores"])
+ result[symbol] = SymbolScore(
+ symbol=symbol,
+ news_score=ns,
+ news_count=data["count"],
+ social_score=ss,
+ policy_score=ps,
+ filing_score=fs,
+ composite=self._compute_composite(ns, ss, ps, fs),
+ updated_at=now,
+ )
+ return result
+
+ def determine_regime(self, fear_greed: int, vix: float | None) -> str:
+ if fear_greed <= 20:
+ return "risk_off"
+ if vix is not None and vix > 30:
+ return "risk_off"
+ if fear_greed >= 60 and (vix is None or vix < 20):
+ return "risk_on"
+ return "neutral"
diff --git a/shared/src/shared/sentiment_models.py b/shared/src/shared/sentiment_models.py
new file mode 100644
index 0000000..ac06c20
--- /dev/null
+++ b/shared/src/shared/sentiment_models.py
@@ -0,0 +1,43 @@
+"""Sentiment scoring and stock selection models."""
+
+from datetime import datetime
+
+from pydantic import BaseModel
+
+from shared.models import OrderSide
+
+
+class SymbolScore(BaseModel):
+ symbol: str
+ news_score: float
+ news_count: int
+ social_score: float
+ policy_score: float
+ filing_score: float
+ composite: float
+ updated_at: datetime
+
+
+class MarketSentiment(BaseModel):
+ fear_greed: int
+ fear_greed_label: str
+ vix: float | None = None
+ fed_stance: str
+ market_regime: str
+ updated_at: datetime
+
+
+class SelectedStock(BaseModel):
+ symbol: str
+ side: OrderSide
+ conviction: float
+ reason: str
+ key_news: list[str]
+
+
+class Candidate(BaseModel):
+ symbol: str
+ source: str
+ direction: OrderSide | None = None
+ score: float
+ reason: str
diff --git a/shared/src/shared/shutdown.py b/shared/src/shared/shutdown.py
new file mode 100644
index 0000000..4ed9aa7
--- /dev/null
+++ b/shared/src/shared/shutdown.py
@@ -0,0 +1,30 @@
+"""Graceful shutdown utilities for services."""
+
+import asyncio
+import logging
+import signal
+
+logger = logging.getLogger(__name__)
+
+
+class GracefulShutdown:
+ """Manages graceful shutdown via SIGTERM/SIGINT signals."""
+
+ def __init__(self) -> None:
+ self._event = asyncio.Event()
+
+ @property
+ def is_shutting_down(self) -> bool:
+ return self._event.is_set()
+
+ async def wait(self) -> None:
+ await self._event.wait()
+
+ def trigger(self) -> None:
+ logger.info("shutdown_signal_received")
+ self._event.set()
+
+ def install_handlers(self) -> None:
+ loop = asyncio.get_running_loop()
+ for sig in (signal.SIGTERM, signal.SIGINT):
+ loop.add_signal_handler(sig, self.trigger)
diff --git a/shared/tests/test_alpaca.py b/shared/tests/test_alpaca.py
index 080b7c4..55a2b24 100644
--- a/shared/tests/test_alpaca.py
+++ b/shared/tests/test_alpaca.py
@@ -1,7 +1,9 @@
"""Tests for Alpaca API client."""
-import pytest
from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
from shared.alpaca import AlpacaClient
diff --git a/shared/tests/test_broker.py b/shared/tests/test_broker.py
index eb1582d..5636611 100644
--- a/shared/tests/test_broker.py
+++ b/shared/tests/test_broker.py
@@ -1,10 +1,11 @@
"""Tests for the Redis broker."""
-import pytest
import json
-import redis
from unittest.mock import AsyncMock, patch
+import pytest
+import redis
+
@pytest.mark.asyncio
async def test_broker_publish():
diff --git a/shared/tests/test_config_validation.py b/shared/tests/test_config_validation.py
new file mode 100644
index 0000000..9376dc6
--- /dev/null
+++ b/shared/tests/test_config_validation.py
@@ -0,0 +1,29 @@
+"""Tests for config validation."""
+
+import pytest
+from pydantic import ValidationError
+
+from shared.config import Settings
+
+
+class TestConfigValidation:
+ def test_valid_defaults(self):
+ settings = Settings()
+ assert settings.risk_max_position_size == 0.1
+
+ def test_invalid_position_size(self):
+ with pytest.raises(ValidationError, match="risk_max_position_size"):
+ Settings(risk_max_position_size=-0.1)
+
+ def test_invalid_health_port(self):
+ with pytest.raises(ValidationError, match="health_port"):
+ Settings(health_port=80)
+
+ def test_invalid_log_level(self):
+ with pytest.raises(ValidationError, match="log_level"):
+ Settings(log_level="INVALID")
+
+ def test_secret_fields_masked(self):
+ settings = Settings(alpaca_api_key="my-secret-key")
+ assert "my-secret-key" not in repr(settings)
+ assert settings.alpaca_api_key.get_secret_value() == "my-secret-key"
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
index 239ee64..b44a713 100644
--- a/shared/tests/test_db.py
+++ b/shared/tests/test_db.py
@@ -1,10 +1,11 @@
"""Tests for the SQLAlchemy async database layer."""
-import pytest
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
+import pytest
+
def make_candle():
from shared.models import Candle
@@ -12,7 +13,7 @@ def make_candle():
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("51000"),
low=Decimal("49500"),
@@ -22,7 +23,7 @@ def make_candle():
def make_signal():
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
return Signal(
id="sig-1",
@@ -32,12 +33,12 @@ def make_signal():
price=Decimal("50000"),
quantity=Decimal("0.1"),
reason="Golden cross",
- created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
def make_order():
- from shared.models import Order, OrderSide, OrderType, OrderStatus
+ from shared.models import Order, OrderSide, OrderStatus, OrderType
return Order(
id="ord-1",
@@ -48,7 +49,7 @@ def make_order():
price=Decimal("50000"),
quantity=Decimal("0.1"),
status=OrderStatus.PENDING,
- created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
@@ -101,6 +102,54 @@ class TestDatabaseConnect:
mock_create.assert_called_once()
@pytest.mark.asyncio
+ async def test_connect_passes_pool_params_for_postgres(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_conn = AsyncMock()
+ mock_cm = AsyncMock()
+ mock_cm.__aenter__.return_value = mock_conn
+
+ mock_engine = MagicMock()
+ mock_engine.begin.return_value = mock_cm
+ mock_engine.dispose = AsyncMock()
+
+ with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create:
+ with patch("shared.db.async_sessionmaker"):
+ with patch("shared.db.Base") as mock_base:
+ mock_base.metadata.create_all = MagicMock()
+ await db.connect(pool_size=5, max_overflow=3, pool_recycle=1800)
+ mock_create.assert_called_once_with(
+ "postgresql+asyncpg://host/db",
+ pool_pre_ping=True,
+ pool_size=5,
+ max_overflow=3,
+ pool_recycle=1800,
+ )
+
+ @pytest.mark.asyncio
+ async def test_connect_skips_pool_params_for_sqlite(self):
+ from shared.db import Database
+
+ db = Database("sqlite+aiosqlite:///test.db")
+
+ mock_conn = AsyncMock()
+ mock_cm = AsyncMock()
+ mock_cm.__aenter__.return_value = mock_conn
+
+ mock_engine = MagicMock()
+ mock_engine.begin.return_value = mock_cm
+ mock_engine.dispose = AsyncMock()
+
+ with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create:
+ with patch("shared.db.async_sessionmaker"):
+ with patch("shared.db.Base") as mock_base:
+ mock_base.metadata.create_all = MagicMock()
+ await db.connect()
+ mock_create.assert_called_once_with("sqlite+aiosqlite:///test.db")
+
+ @pytest.mark.asyncio
async def test_init_tables_is_alias_for_connect(self):
from shared.db import Database
@@ -211,7 +260,7 @@ class TestUpdateOrderStatus:
db._session_factory = MagicMock(return_value=mock_session)
- filled = datetime(2024, 1, 2, tzinfo=timezone.utc)
+ filled = datetime(2024, 1, 2, tzinfo=UTC)
await db.update_order_status("ord-1", OrderStatus.FILLED, filled)
mock_session.execute.assert_awaited_once()
@@ -230,7 +279,7 @@ class TestGetCandles:
mock_row._mapping = {
"symbol": "AAPL",
"timeframe": "1m",
- "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc),
+ "open_time": datetime(2024, 1, 1, tzinfo=UTC),
"open": Decimal("50000"),
"high": Decimal("51000"),
"low": Decimal("49500"),
@@ -396,7 +445,7 @@ class TestGetPortfolioSnapshots:
mock_row.total_value = Decimal("10000")
mock_row.realized_pnl = Decimal("0")
mock_row.unrealized_pnl = Decimal("500")
- mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
+ mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=UTC)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [mock_row]
diff --git a/shared/tests/test_db_news.py b/shared/tests/test_db_news.py
new file mode 100644
index 0000000..c184bed
--- /dev/null
+++ b/shared/tests/test_db_news.py
@@ -0,0 +1,79 @@
+"""Tests for database news/sentiment methods. Uses in-memory SQLite."""
+
+from datetime import UTC, date, datetime
+
+import pytest
+
+from shared.db import Database
+from shared.models import NewsCategory, NewsItem
+from shared.sentiment_models import MarketSentiment, SymbolScore
+
+
+@pytest.fixture
+async def db():
+ database = Database("sqlite+aiosqlite://")
+ await database.connect()
+ yield database
+ await database.close()
+
+
+async def test_insert_and_get_news_items(db):
+ item = NewsItem(
+ source="finnhub",
+ headline="AAPL earnings beat",
+ published_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC),
+ sentiment=0.8,
+ category=NewsCategory.EARNINGS,
+ symbols=["AAPL"],
+ )
+ await db.insert_news_item(item)
+ items = await db.get_recent_news(hours=24)
+ assert len(items) == 1
+ assert items[0]["headline"] == "AAPL earnings beat"
+
+
+async def test_upsert_symbol_score(db):
+ score = SymbolScore(
+ symbol="AAPL",
+ news_score=0.5,
+ news_count=10,
+ social_score=0.3,
+ policy_score=0.0,
+ filing_score=0.2,
+ composite=0.3,
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
+ )
+ await db.upsert_symbol_score(score)
+ scores = await db.get_top_symbol_scores(limit=5)
+ assert len(scores) == 1
+ assert scores[0]["symbol"] == "AAPL"
+
+
+async def test_upsert_market_sentiment(db):
+ ms = MarketSentiment(
+ fear_greed=55,
+ fear_greed_label="Neutral",
+ vix=18.2,
+ fed_stance="neutral",
+ market_regime="neutral",
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
+ )
+ await db.upsert_market_sentiment(ms)
+ result = await db.get_latest_market_sentiment()
+ assert result is not None
+ assert result["fear_greed"] == 55
+
+
+async def test_insert_stock_selection(db):
+ await db.insert_stock_selection(
+ trade_date=date(2026, 4, 2),
+ symbol="NVDA",
+ side="BUY",
+ conviction=0.85,
+ reason="CHIPS Act",
+ key_news=["Trump signs CHIPS expansion"],
+ sentiment_snapshot={"composite": 0.8},
+ )
+ selections = await db.get_stock_selections(date(2026, 4, 2))
+ assert len(selections) == 1
+ assert selections[0]["symbol"] == "NVDA"
diff --git a/shared/tests/test_events.py b/shared/tests/test_events.py
index 6077d93..1ccd904 100644
--- a/shared/tests/test_events.py
+++ b/shared/tests/test_events.py
@@ -1,7 +1,7 @@
"""Tests for shared event types."""
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
def make_candle():
@@ -10,7 +10,7 @@ def make_candle():
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("51000"),
low=Decimal("49500"),
@@ -20,7 +20,7 @@ def make_candle():
def make_signal():
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
return Signal(
strategy="test",
@@ -59,7 +59,7 @@ def test_candle_event_deserialize():
def test_signal_event_serialize():
"""Test SignalEvent serializes to dict correctly."""
- from shared.events import SignalEvent, EventType
+ from shared.events import EventType, SignalEvent
signal = make_signal()
event = SignalEvent(data=signal)
@@ -71,7 +71,7 @@ def test_signal_event_serialize():
def test_event_from_dict_dispatch():
"""Test Event.from_dict dispatches to correct class."""
- from shared.events import Event, CandleEvent, SignalEvent
+ from shared.events import CandleEvent, Event, SignalEvent
candle = make_candle()
event = CandleEvent(data=candle)
diff --git a/shared/tests/test_models.py b/shared/tests/test_models.py
index 04098ce..40bb791 100644
--- a/shared/tests/test_models.py
+++ b/shared/tests/test_models.py
@@ -1,8 +1,8 @@
"""Tests for shared models and settings."""
import os
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
from unittest.mock import patch
@@ -12,8 +12,11 @@ def test_settings_defaults():
with patch.dict(os.environ, {}, clear=False):
settings = Settings()
- assert settings.redis_url == "redis://localhost:6379"
- assert settings.database_url == "postgresql://trading:trading@localhost:5432/trading"
+ assert settings.redis_url.get_secret_value() == "redis://localhost:6379"
+ assert (
+ settings.database_url.get_secret_value()
+ == "postgresql://trading:trading@localhost:5432/trading"
+ )
assert settings.log_level == "INFO"
assert settings.risk_max_position_size == 0.1
assert settings.risk_stop_loss_pct == 5.0
@@ -25,7 +28,7 @@ def test_candle_creation():
"""Test Candle model creation."""
from shared.models import Candle
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
candle = Candle(
symbol="AAPL",
timeframe="1m",
@@ -47,7 +50,7 @@ def test_candle_creation():
def test_signal_creation():
"""Test Signal model creation."""
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
signal = Signal(
strategy="rsi_strategy",
@@ -69,9 +72,10 @@ def test_signal_creation():
def test_order_creation():
"""Test Order model creation with defaults."""
- from shared.models import Order, OrderSide, OrderType, OrderStatus
import uuid
+ from shared.models import Order, OrderSide, OrderStatus, OrderType
+
signal_id = str(uuid.uuid4())
order = Order(
signal_id=signal_id,
@@ -90,7 +94,7 @@ def test_order_creation():
def test_signal_conviction_default():
"""Test Signal defaults for conviction, stop_loss, take_profit."""
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
signal = Signal(
strategy="rsi",
@@ -107,7 +111,7 @@ def test_signal_conviction_default():
def test_signal_with_stops():
"""Test Signal with explicit conviction, stop_loss, take_profit."""
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
signal = Signal(
strategy="rsi",
diff --git a/shared/tests/test_news_events.py b/shared/tests/test_news_events.py
new file mode 100644
index 0000000..f748d8a
--- /dev/null
+++ b/shared/tests/test_news_events.py
@@ -0,0 +1,56 @@
+"""Tests for NewsEvent."""
+
+from datetime import UTC, datetime
+
+from shared.events import Event, EventType, NewsEvent
+from shared.models import NewsCategory, NewsItem
+
+
+def test_news_event_to_dict():
+ item = NewsItem(
+ source="finnhub",
+ headline="Test",
+ published_at=datetime(2026, 4, 2, tzinfo=UTC),
+ sentiment=0.5,
+ category=NewsCategory.MACRO,
+ )
+ event = NewsEvent(data=item)
+ d = event.to_dict()
+ assert d["type"] == EventType.NEWS
+ assert d["data"]["source"] == "finnhub"
+
+
+def test_news_event_from_raw():
+ raw = {
+ "type": "NEWS",
+ "data": {
+ "id": "abc",
+ "source": "rss",
+ "headline": "Test headline",
+ "published_at": "2026-04-02T00:00:00+00:00",
+ "sentiment": 0.3,
+ "category": "earnings",
+ "symbols": ["AAPL"],
+ "raw_data": {},
+ },
+ }
+ event = NewsEvent.from_raw(raw)
+ assert event.data.source == "rss"
+ assert event.data.symbols == ["AAPL"]
+
+
+def test_event_dispatcher_news():
+ raw = {
+ "type": "NEWS",
+ "data": {
+ "id": "abc",
+ "source": "finnhub",
+ "headline": "Test",
+ "published_at": "2026-04-02T00:00:00+00:00",
+ "sentiment": 0.0,
+ "category": "macro",
+ "raw_data": {},
+ },
+ }
+ event = Event.from_dict(raw)
+ assert isinstance(event, NewsEvent)
diff --git a/shared/tests/test_notifier.py b/shared/tests/test_notifier.py
index 6c81369..cc98a56 100644
--- a/shared/tests/test_notifier.py
+++ b/shared/tests/test_notifier.py
@@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
-from shared.models import Signal, Order, OrderSide, OrderType, OrderStatus, Position
+from shared.models import Order, OrderSide, OrderStatus, OrderType, Position, Signal
from shared.notifier import TelegramNotifier
diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py
index e287777..e0781af 100644
--- a/shared/tests/test_resilience.py
+++ b/shared/tests/test_resilience.py
@@ -1,139 +1,176 @@
-"""Tests for retry with backoff and circuit breaker."""
+"""Tests for shared.resilience module."""
-import time
+import asyncio
import pytest
-from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff
+from shared.resilience import CircuitBreaker, async_timeout, retry_async
+# --- retry_async tests ---
-# ---------------------------------------------------------------------------
-# retry_with_backoff tests
-# ---------------------------------------------------------------------------
-
-@pytest.mark.asyncio
-async def test_retry_succeeds_first_try():
+async def test_succeeds_without_retry():
+ """Function succeeds first try, called once."""
call_count = 0
- @retry_with_backoff(max_retries=3, base_delay=0.01)
- async def succeed():
+ @retry_async()
+ async def fn():
nonlocal call_count
call_count += 1
return "ok"
- result = await succeed()
+ result = await fn()
assert result == "ok"
assert call_count == 1
-@pytest.mark.asyncio
-async def test_retry_succeeds_after_failures():
+async def test_retries_on_failure_then_succeeds():
+ """Fails twice then succeeds, verify call count."""
call_count = 0
- @retry_with_backoff(max_retries=3, base_delay=0.01)
- async def flaky():
+ @retry_async(max_retries=3, base_delay=0.01)
+ async def fn():
nonlocal call_count
call_count += 1
if call_count < 3:
- raise ValueError("not yet")
+ raise RuntimeError("transient")
return "recovered"
- result = await flaky()
+ result = await fn()
assert result == "recovered"
assert call_count == 3
-@pytest.mark.asyncio
-async def test_retry_raises_after_max_retries():
+async def test_raises_after_max_retries():
+ """Always fails, raises after max retries."""
call_count = 0
- @retry_with_backoff(max_retries=3, base_delay=0.01)
- async def always_fail():
+ @retry_async(max_retries=3, base_delay=0.01)
+ async def fn():
nonlocal call_count
call_count += 1
- raise RuntimeError("permanent")
+ raise ValueError("permanent")
- with pytest.raises(RuntimeError, match="permanent"):
- await always_fail()
- # 1 initial + 3 retries = 4 calls
+ with pytest.raises(ValueError, match="permanent"):
+ await fn()
+
+ # 1 initial + 3 retries = 4 total calls
assert call_count == 4
-@pytest.mark.asyncio
-async def test_retry_respects_max_delay():
- """Backoff should be capped at max_delay."""
+async def test_no_retry_on_excluded_exception():
+ """Excluded exception raises immediately, call count = 1."""
+ call_count = 0
- @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02)
- async def always_fail():
- raise RuntimeError("fail")
+ @retry_async(max_retries=3, base_delay=0.01, exclude=(TypeError,))
+ async def fn():
+ nonlocal call_count
+ call_count += 1
+ raise TypeError("excluded")
- start = time.monotonic()
- with pytest.raises(RuntimeError):
- await always_fail()
- elapsed = time.monotonic() - start
- # With max_delay=0.02 and 2 retries, total delay should be small
- assert elapsed < 0.5
+ with pytest.raises(TypeError, match="excluded"):
+ await fn()
+
+ assert call_count == 1
-# ---------------------------------------------------------------------------
-# CircuitBreaker tests
-# ---------------------------------------------------------------------------
+# --- CircuitBreaker tests ---
-def test_circuit_starts_closed():
- cb = CircuitBreaker(failure_threshold=3, recovery_timeout=0.05)
- assert cb.state == CircuitState.CLOSED
- assert cb.allow_request() is True
+async def test_closed_allows_calls():
+ """CircuitBreaker in closed state passes through."""
+ cb = CircuitBreaker(failure_threshold=5, cooldown=60.0)
+ async def fn():
+ return "ok"
+
+ result = await cb.call(fn)
+ assert result == "ok"
+
+
+async def test_opens_after_threshold():
+ """After N failures, raises RuntimeError."""
+ cb = CircuitBreaker(failure_threshold=3, cooldown=60.0)
+
+ async def fail():
+ raise RuntimeError("fail")
-def test_circuit_opens_after_threshold():
- cb = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0)
for _ in range(3):
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
- assert cb.allow_request() is False
+ with pytest.raises(RuntimeError, match="fail"):
+ await cb.call(fail)
+ # Now the breaker should be open
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(fail)
+
+
+async def test_half_open_after_cooldown():
+ """After cooldown, allows recovery attempt."""
+ cb = CircuitBreaker(failure_threshold=2, cooldown=0.05)
+
+ async def fail():
+ raise RuntimeError("fail")
+
+ # Trip the breaker
+ for _ in range(2):
+ with pytest.raises(RuntimeError, match="fail"):
+ await cb.call(fail)
+
+ # Breaker is open
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(fail)
+
+ # Wait for cooldown
+ await asyncio.sleep(0.06)
+
+ # Now should allow a call (half_open). Succeed to close it.
+ async def succeed():
+ return "recovered"
+
+ result = await cb.call(succeed)
+ assert result == "recovered"
+
+ # Breaker should be closed again
+ result = await cb.call(succeed)
+ assert result == "recovered"
+
+
+async def test_half_open_reopens_on_failure():
+ cb = CircuitBreaker(failure_threshold=2, cooldown=0.05)
+
+ async def always_fail():
+ raise ConnectionError("fail")
-def test_circuit_rejects_when_open():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0)
- cb.record_failure()
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
- assert cb.allow_request() is False
+ # Trip the breaker
+ for _ in range(2):
+ with pytest.raises(ConnectionError):
+ await cb.call(always_fail)
+ # Wait for cooldown
+ await asyncio.sleep(0.1)
-def test_circuit_half_open_after_timeout():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
- cb.record_failure()
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
+ # Half-open probe should fail and re-open
+ with pytest.raises(ConnectionError):
+ await cb.call(always_fail)
- time.sleep(0.06)
- assert cb.allow_request() is True
- assert cb.state == CircuitState.HALF_OPEN
+ # Should be open again (no cooldown wait)
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(always_fail)
-def test_circuit_closes_on_success():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
- cb.record_failure()
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
+# --- async_timeout tests ---
- time.sleep(0.06)
- cb.allow_request() # triggers HALF_OPEN
- assert cb.state == CircuitState.HALF_OPEN
- cb.record_success()
- assert cb.state == CircuitState.CLOSED
- assert cb.allow_request() is True
+async def test_completes_within_timeout():
+ """async_timeout doesn't interfere with fast operations."""
+ async with async_timeout(1.0):
+ await asyncio.sleep(0.01)
+ result = 42
+ assert result == 42
-def test_circuit_reopens_on_failure_in_half_open():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
- cb.record_failure()
- cb.record_failure()
- time.sleep(0.06)
- cb.allow_request() # HALF_OPEN
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
+async def test_raises_on_timeout():
+ """async_timeout raises TimeoutError for slow operations."""
+ with pytest.raises(TimeoutError):
+ async with async_timeout(0.05):
+ await asyncio.sleep(1.0)
diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py
index 67c3c82..c9311dd 100644
--- a/shared/tests/test_sa_models.py
+++ b/shared/tests/test_sa_models.py
@@ -14,6 +14,10 @@ def test_base_metadata_has_all_tables():
"trades",
"positions",
"portfolio_snapshots",
+ "news_items",
+ "symbol_scores",
+ "market_sentiment",
+ "stock_selections",
}
assert expected == table_names
@@ -68,6 +72,9 @@ class TestSignalRow:
"price",
"quantity",
"reason",
+ "conviction",
+ "stop_loss",
+ "take_profit",
"created_at",
}
assert expected == cols
@@ -120,44 +127,6 @@ class TestOrderRow:
assert fk_cols == {"signal_id": "signals.id"}
-class TestTradeRow:
- def test_table_name(self):
- from shared.sa_models import TradeRow
-
- assert TradeRow.__tablename__ == "trades"
-
- def test_columns(self):
- from shared.sa_models import TradeRow
-
- mapper = inspect(TradeRow)
- cols = {c.key for c in mapper.column_attrs}
- expected = {
- "id",
- "order_id",
- "symbol",
- "side",
- "price",
- "quantity",
- "fee",
- "traded_at",
- }
- assert expected == cols
-
- def test_primary_key(self):
- from shared.sa_models import TradeRow
-
- mapper = inspect(TradeRow)
- pk_cols = [c.name for c in mapper.mapper.primary_key]
- assert pk_cols == ["id"]
-
- def test_order_id_foreign_key(self):
- from shared.sa_models import TradeRow
-
- table = TradeRow.__table__
- fk_cols = {fk.parent.name: fk.target_fullname for fk in table.foreign_keys}
- assert fk_cols == {"order_id": "orders.id"}
-
-
class TestPositionRow:
def test_table_name(self):
from shared.sa_models import PositionRow
@@ -229,11 +198,3 @@ class TestStatusDefault:
status_col = table.c.status
assert status_col.server_default is not None
assert status_col.server_default.arg == "PENDING"
-
- def test_trade_fee_server_default(self):
- from shared.sa_models import TradeRow
-
- table = TradeRow.__table__
- fee_col = table.c.fee
- assert fee_col.server_default is not None
- assert fee_col.server_default.arg == "0"
diff --git a/shared/tests/test_sa_news_models.py b/shared/tests/test_sa_news_models.py
new file mode 100644
index 0000000..dc2d026
--- /dev/null
+++ b/shared/tests/test_sa_news_models.py
@@ -0,0 +1,29 @@
+"""Tests for news-related SQLAlchemy models."""
+
+from shared.sa_models import MarketSentimentRow, NewsItemRow, StockSelectionRow, SymbolScoreRow
+
+
+def test_news_item_row_tablename():
+ assert NewsItemRow.__tablename__ == "news_items"
+
+
+def test_symbol_score_row_tablename():
+ assert SymbolScoreRow.__tablename__ == "symbol_scores"
+
+
+def test_market_sentiment_row_tablename():
+ assert MarketSentimentRow.__tablename__ == "market_sentiment"
+
+
+def test_stock_selection_row_tablename():
+ assert StockSelectionRow.__tablename__ == "stock_selections"
+
+
+def test_news_item_row_columns():
+ cols = {c.name for c in NewsItemRow.__table__.columns}
+ assert cols >= {"id", "source", "headline", "published_at", "sentiment", "category"}
+
+
+def test_symbol_score_row_columns():
+ cols = {c.name for c in SymbolScoreRow.__table__.columns}
+ assert cols >= {"id", "symbol", "news_score", "composite", "updated_at"}
diff --git a/shared/tests/test_sentiment.py b/shared/tests/test_sentiment.py
deleted file mode 100644
index 9bd8ea3..0000000
--- a/shared/tests/test_sentiment.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""Tests for market sentiment module."""
-
-from shared.sentiment import SentimentData
-
-
-def test_sentiment_should_buy_default_no_data():
- s = SentimentData()
- assert s.should_buy is True
- assert s.should_block is False
-
-
-def test_sentiment_should_buy_low_fear_greed():
- s = SentimentData(fear_greed_value=15)
- assert s.should_buy is True
-
-
-def test_sentiment_should_not_buy_on_greed():
- s = SentimentData(fear_greed_value=75)
- assert s.should_buy is False
-
-
-def test_sentiment_should_not_buy_negative_news():
- s = SentimentData(news_sentiment=-0.4)
- assert s.should_buy is False
-
-
-def test_sentiment_should_buy_positive_news():
- s = SentimentData(fear_greed_value=50, news_sentiment=0.3)
- assert s.should_buy is True
-
-
-def test_sentiment_should_block_extreme_greed():
- s = SentimentData(fear_greed_value=85)
- assert s.should_block is True
-
-
-def test_sentiment_should_block_very_negative_news():
- s = SentimentData(news_sentiment=-0.6)
- assert s.should_block is True
-
-
-def test_sentiment_no_block_on_neutral():
- s = SentimentData(fear_greed_value=50, news_sentiment=0.0)
- assert s.should_block is False
diff --git a/shared/tests/test_sentiment_aggregator.py b/shared/tests/test_sentiment_aggregator.py
new file mode 100644
index 0000000..9193785
--- /dev/null
+++ b/shared/tests/test_sentiment_aggregator.py
@@ -0,0 +1,79 @@
+"""Tests for sentiment aggregator."""
+
+from datetime import UTC, datetime, timedelta
+
+import pytest
+
+from shared.sentiment import SentimentAggregator
+
+
+@pytest.fixture
+def aggregator():
+ return SentimentAggregator()
+
+
+def test_freshness_decay_recent():
+ a = SentimentAggregator()
+ now = datetime.now(UTC)
+ assert a._freshness_decay(now, now) == 1.0
+
+
+def test_freshness_decay_3_hours():
+ a = SentimentAggregator()
+ now = datetime.now(UTC)
+ assert a._freshness_decay(now - timedelta(hours=3), now) == 0.7
+
+
+def test_freshness_decay_12_hours():
+ a = SentimentAggregator()
+ now = datetime.now(UTC)
+ assert a._freshness_decay(now - timedelta(hours=12), now) == 0.3
+
+
+def test_freshness_decay_old():
+ a = SentimentAggregator()
+ now = datetime.now(UTC)
+ assert a._freshness_decay(now - timedelta(days=2), now) == 0.0
+
+
+def test_compute_composite():
+ a = SentimentAggregator()
+ composite = a._compute_composite(
+ news_score=0.5, social_score=0.3, policy_score=0.8, filing_score=0.2
+ )
+ expected = 0.5 * 0.3 + 0.3 * 0.2 + 0.8 * 0.3 + 0.2 * 0.2
+ assert abs(composite - expected) < 0.001
+
+
+def test_aggregate_news_by_symbol(aggregator):
+ now = datetime.now(UTC)
+ news_items = [
+ {"symbols": ["AAPL"], "sentiment": 0.8, "category": "earnings", "published_at": now},
+ {
+ "symbols": ["AAPL"],
+ "sentiment": 0.3,
+ "category": "macro",
+ "published_at": now - timedelta(hours=2),
+ },
+ {"symbols": ["MSFT"], "sentiment": -0.5, "category": "policy", "published_at": now},
+ ]
+ scores = aggregator.aggregate(news_items, now)
+ assert "AAPL" in scores
+ assert "MSFT" in scores
+ assert scores["AAPL"].news_count == 2
+ assert scores["AAPL"].news_score > 0
+ assert scores["MSFT"].policy_score < 0
+
+
+def test_aggregate_empty(aggregator):
+ now = datetime.now(UTC)
+ assert aggregator.aggregate([], now) == {}
+
+
+def test_determine_regime():
+ a = SentimentAggregator()
+ assert a.determine_regime(15, None) == "risk_off"
+ assert a.determine_regime(15, 35.0) == "risk_off"
+ assert a.determine_regime(50, 35.0) == "risk_off"
+ assert a.determine_regime(70, 15.0) == "risk_on"
+ assert a.determine_regime(50, 20.0) == "neutral"
diff --git a/shared/tests/test_sentiment_models.py b/shared/tests/test_sentiment_models.py
new file mode 100644
index 0000000..e00ffa6
--- /dev/null
+++ b/shared/tests/test_sentiment_models.py
@@ -0,0 +1,113 @@
+"""Tests for news and sentiment models."""
+
+from datetime import UTC, datetime
+
+from shared.models import NewsCategory, NewsItem, OrderSide
+from shared.sentiment_models import Candidate, MarketSentiment, SelectedStock, SymbolScore
+
+
+def test_news_item_defaults():
+ item = NewsItem(
+ source="finnhub",
+ headline="Test headline",
+ published_at=datetime(2026, 4, 2, tzinfo=UTC),
+ sentiment=0.5,
+ category=NewsCategory.MACRO,
+ )
+ assert item.id
+ assert item.symbols == []
+ assert item.summary is None
+ assert item.raw_data == {}
+ assert item.created_at is not None
+
+
+def test_news_item_with_symbols():
+ item = NewsItem(
+ source="rss",
+ headline="AAPL earnings beat",
+ published_at=datetime(2026, 4, 2, tzinfo=UTC),
+ sentiment=0.8,
+ category=NewsCategory.EARNINGS,
+ symbols=["AAPL"],
+ )
+ assert item.symbols == ["AAPL"]
+ assert item.category == NewsCategory.EARNINGS
+
+
+def test_news_category_values():
+ assert NewsCategory.POLICY == "policy"
+ assert NewsCategory.EARNINGS == "earnings"
+ assert NewsCategory.MACRO == "macro"
+ assert NewsCategory.SOCIAL == "social"
+ assert NewsCategory.FILING == "filing"
+ assert NewsCategory.FED == "fed"
+
+
+def test_symbol_score():
+ score = SymbolScore(
+ symbol="AAPL",
+ news_score=0.5,
+ news_count=10,
+ social_score=0.3,
+ policy_score=0.0,
+ filing_score=0.2,
+ composite=0.3,
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
+ )
+ assert score.symbol == "AAPL"
+ assert score.composite == 0.3
+
+
+def test_market_sentiment():
+ ms = MarketSentiment(
+ fear_greed=25,
+ fear_greed_label="Extreme Fear",
+ vix=32.5,
+ fed_stance="hawkish",
+ market_regime="risk_off",
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
+ )
+ assert ms.market_regime == "risk_off"
+ assert ms.vix == 32.5
+
+
+def test_market_sentiment_no_vix():
+ ms = MarketSentiment(
+ fear_greed=50,
+ fear_greed_label="Neutral",
+ fed_stance="neutral",
+ market_regime="neutral",
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
+ )
+ assert ms.vix is None
+
+
+def test_selected_stock():
+ ss = SelectedStock(
+ symbol="NVDA",
+ side=OrderSide.BUY,
+ conviction=0.85,
+ reason="CHIPS Act expansion",
+ key_news=["Trump signs CHIPS Act expansion"],
+ )
+ assert ss.conviction == 0.85
+ assert len(ss.key_news) == 1
+
+
+def test_candidate():
+ c = Candidate(
+ symbol="TSLA",
+ source="sentiment",
+ direction=OrderSide.BUY,
+ score=0.75,
+ reason="High social buzz",
+ )
+ assert c.direction == OrderSide.BUY
+
+ c2 = Candidate(
+ symbol="XOM",
+ source="llm",
+ score=0.6,
+ reason="Oil price surge",
+ )
+ assert c2.direction is None