diff options
Diffstat (limited to 'shared/tests')
| -rw-r--r-- | shared/tests/test_alpaca.py | 4 | ||||
| -rw-r--r-- | shared/tests/test_broker.py | 11 | ||||
| -rw-r--r-- | shared/tests/test_config_validation.py | 29 | ||||
| -rw-r--r-- | shared/tests/test_db.py | 69 | ||||
| -rw-r--r-- | shared/tests/test_db_news.py | 79 | ||||
| -rw-r--r-- | shared/tests/test_events.py | 10 | ||||
| -rw-r--r-- | shared/tests/test_models.py | 20 | ||||
| -rw-r--r-- | shared/tests/test_news_events.py | 56 | ||||
| -rw-r--r-- | shared/tests/test_notifier.py | 14 | ||||
| -rw-r--r-- | shared/tests/test_resilience.py | 203 | ||||
| -rw-r--r-- | shared/tests/test_sa_models.py | 53 | ||||
| -rw-r--r-- | shared/tests/test_sa_news_models.py | 29 | ||||
| -rw-r--r-- | shared/tests/test_sentiment.py | 44 | ||||
| -rw-r--r-- | shared/tests/test_sentiment_aggregator.py | 79 | ||||
| -rw-r--r-- | shared/tests/test_sentiment_models.py | 113 |
15 files changed, 604 insertions, 209 deletions
diff --git a/shared/tests/test_alpaca.py b/shared/tests/test_alpaca.py index 080b7c4..55a2b24 100644 --- a/shared/tests/test_alpaca.py +++ b/shared/tests/test_alpaca.py @@ -1,7 +1,9 @@ """Tests for Alpaca API client.""" -import pytest from unittest.mock import AsyncMock, MagicMock + +import pytest + from shared.alpaca import AlpacaClient diff --git a/shared/tests/test_broker.py b/shared/tests/test_broker.py index 9be84b0..5636611 100644 --- a/shared/tests/test_broker.py +++ b/shared/tests/test_broker.py @@ -1,10 +1,11 @@ """Tests for the Redis broker.""" -import pytest import json -import redis from unittest.mock import AsyncMock, patch +import pytest +import redis + @pytest.mark.asyncio async def test_broker_publish(): @@ -16,7 +17,7 @@ async def test_broker_publish(): from shared.broker import RedisBroker broker = RedisBroker("redis://localhost:6379") - data = {"type": "CANDLE", "symbol": "BTCUSDT"} + data = {"type": "CANDLE", "symbol": "AAPL"} await broker.publish("candles", data) mock_redis.xadd.assert_called_once() @@ -35,7 +36,7 @@ async def test_broker_subscribe_returns_messages(): mock_redis = AsyncMock() mock_from_url.return_value = mock_redis - payload_data = {"type": "CANDLE", "symbol": "ETHUSDT"} + payload_data = {"type": "CANDLE", "symbol": "MSFT"} mock_redis.xread.return_value = [ [ b"candles", @@ -53,7 +54,7 @@ async def test_broker_subscribe_returns_messages(): mock_redis.xread.assert_called_once() assert len(messages) == 1 assert messages[0]["type"] == "CANDLE" - assert messages[0]["symbol"] == "ETHUSDT" + assert messages[0]["symbol"] == "MSFT" @pytest.mark.asyncio diff --git a/shared/tests/test_config_validation.py b/shared/tests/test_config_validation.py new file mode 100644 index 0000000..9376dc6 --- /dev/null +++ b/shared/tests/test_config_validation.py @@ -0,0 +1,29 @@ +"""Tests for config validation.""" + +import pytest +from pydantic import ValidationError + +from shared.config import Settings + + +class TestConfigValidation: + def test_valid_defaults(self): + settings = Settings() + assert settings.risk_max_position_size == 0.1 + + def test_invalid_position_size(self): + with pytest.raises(ValidationError, match="risk_max_position_size"): + Settings(risk_max_position_size=-0.1) + + def test_invalid_health_port(self): + with pytest.raises(ValidationError, match="health_port"): + Settings(health_port=80) + + def test_invalid_log_level(self): + with pytest.raises(ValidationError, match="log_level"): + Settings(log_level="INVALID") + + def test_secret_fields_masked(self): + settings = Settings(alpaca_api_key="my-secret-key") + assert "my-secret-key" not in repr(settings) + assert settings.alpaca_api_key.get_secret_value() == "my-secret-key" diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index 239ee64..b44a713 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -1,10 +1,11 @@ """Tests for the SQLAlchemy async database layer.""" -import pytest +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch +import pytest + def make_candle(): from shared.models import Candle @@ -12,7 +13,7 @@ def make_candle(): return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("51000"), low=Decimal("49500"), @@ -22,7 +23,7 @@ def make_candle(): def make_signal(): - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal return Signal( id="sig-1", @@ -32,12 +33,12 @@ def make_signal(): price=Decimal("50000"), quantity=Decimal("0.1"), reason="Golden cross", - created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + created_at=datetime(2024, 1, 1, tzinfo=UTC), ) def make_order(): - from shared.models import Order, OrderSide, OrderType, OrderStatus + from shared.models import Order, OrderSide, OrderStatus, OrderType return Order( id="ord-1", @@ -48,7 +49,7 @@ def make_order(): price=Decimal("50000"), quantity=Decimal("0.1"), status=OrderStatus.PENDING, - created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + created_at=datetime(2024, 1, 1, tzinfo=UTC), ) @@ -101,6 +102,54 @@ class TestDatabaseConnect: mock_create.assert_called_once() @pytest.mark.asyncio + async def test_connect_passes_pool_params_for_postgres(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect(pool_size=5, max_overflow=3, pool_recycle=1800) + mock_create.assert_called_once_with( + "postgresql+asyncpg://host/db", + pool_pre_ping=True, + pool_size=5, + max_overflow=3, + pool_recycle=1800, + ) + + @pytest.mark.asyncio + async def test_connect_skips_pool_params_for_sqlite(self): + from shared.db import Database + + db = Database("sqlite+aiosqlite:///test.db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect() + mock_create.assert_called_once_with("sqlite+aiosqlite:///test.db") + + @pytest.mark.asyncio async def test_init_tables_is_alias_for_connect(self): from shared.db import Database @@ -211,7 +260,7 @@ class TestUpdateOrderStatus: db._session_factory = MagicMock(return_value=mock_session) - filled = datetime(2024, 1, 2, tzinfo=timezone.utc) + filled = datetime(2024, 1, 2, tzinfo=UTC) await db.update_order_status("ord-1", OrderStatus.FILLED, filled) mock_session.execute.assert_awaited_once() @@ -230,7 +279,7 @@ class TestGetCandles: mock_row._mapping = { "symbol": "AAPL", "timeframe": "1m", - "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc), + "open_time": datetime(2024, 1, 1, tzinfo=UTC), "open": Decimal("50000"), "high": Decimal("51000"), "low": Decimal("49500"), @@ -396,7 +445,7 @@ class TestGetPortfolioSnapshots: mock_row.total_value = Decimal("10000") mock_row.realized_pnl = Decimal("0") mock_row.unrealized_pnl = Decimal("500") - mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=timezone.utc) + mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=UTC) mock_result = MagicMock() mock_result.scalars.return_value.all.return_value = [mock_row] diff --git a/shared/tests/test_db_news.py b/shared/tests/test_db_news.py new file mode 100644 index 0000000..c184bed --- /dev/null +++ b/shared/tests/test_db_news.py @@ -0,0 +1,79 @@ +"""Tests for database news/sentiment methods. Uses in-memory SQLite.""" + +from datetime import UTC, date, datetime + +import pytest + +from shared.db import Database +from shared.models import NewsCategory, NewsItem +from shared.sentiment_models import MarketSentiment, SymbolScore + + +@pytest.fixture +async def db(): + database = Database("sqlite+aiosqlite://") + await database.connect() + yield database + await database.close() + + +async def test_insert_and_get_news_items(db): + item = NewsItem( + source="finnhub", + headline="AAPL earnings beat", + published_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC), + sentiment=0.8, + category=NewsCategory.EARNINGS, + symbols=["AAPL"], + ) + await db.insert_news_item(item) + items = await db.get_recent_news(hours=24) + assert len(items) == 1 + assert items[0]["headline"] == "AAPL earnings beat" + + +async def test_upsert_symbol_score(db): + score = SymbolScore( + symbol="AAPL", + news_score=0.5, + news_count=10, + social_score=0.3, + policy_score=0.0, + filing_score=0.2, + composite=0.3, + updated_at=datetime(2026, 4, 2, tzinfo=UTC), + ) + await db.upsert_symbol_score(score) + scores = await db.get_top_symbol_scores(limit=5) + assert len(scores) == 1 + assert scores[0]["symbol"] == "AAPL" + + +async def test_upsert_market_sentiment(db): + ms = MarketSentiment( + fear_greed=55, + fear_greed_label="Neutral", + vix=18.2, + fed_stance="neutral", + market_regime="neutral", + updated_at=datetime(2026, 4, 2, tzinfo=UTC), + ) + await db.upsert_market_sentiment(ms) + result = await db.get_latest_market_sentiment() + assert result is not None + assert result["fear_greed"] == 55 + + +async def test_insert_stock_selection(db): + await db.insert_stock_selection( + trade_date=date(2026, 4, 2), + symbol="NVDA", + side="BUY", + conviction=0.85, + reason="CHIPS Act", + key_news=["Trump signs CHIPS expansion"], + sentiment_snapshot={"composite": 0.8}, + ) + selections = await db.get_stock_selections(date(2026, 4, 2)) + assert len(selections) == 1 + assert selections[0]["symbol"] == "NVDA" diff --git a/shared/tests/test_events.py b/shared/tests/test_events.py index 6077d93..1ccd904 100644 --- a/shared/tests/test_events.py +++ b/shared/tests/test_events.py @@ -1,7 +1,7 @@ """Tests for shared event types.""" +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone def make_candle(): @@ -10,7 +10,7 @@ def make_candle(): return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("51000"), low=Decimal("49500"), @@ -20,7 +20,7 @@ def make_candle(): def make_signal(): - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal return Signal( strategy="test", @@ -59,7 +59,7 @@ def test_candle_event_deserialize(): def test_signal_event_serialize(): """Test SignalEvent serializes to dict correctly.""" - from shared.events import SignalEvent, EventType + from shared.events import EventType, SignalEvent signal = make_signal() event = SignalEvent(data=signal) @@ -71,7 +71,7 @@ def test_signal_event_serialize(): def test_event_from_dict_dispatch(): """Test Event.from_dict dispatches to correct class.""" - from shared.events import Event, CandleEvent, SignalEvent + from shared.events import CandleEvent, Event, SignalEvent candle = make_candle() event = CandleEvent(data=candle) diff --git a/shared/tests/test_models.py b/shared/tests/test_models.py index 04098ce..40bb791 100644 --- a/shared/tests/test_models.py +++ b/shared/tests/test_models.py @@ -1,8 +1,8 @@ """Tests for shared models and settings.""" import os +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone from unittest.mock import patch @@ -12,8 +12,11 @@ def test_settings_defaults(): with patch.dict(os.environ, {}, clear=False): settings = Settings() - assert settings.redis_url == "redis://localhost:6379" - assert settings.database_url == "postgresql://trading:trading@localhost:5432/trading" + assert settings.redis_url.get_secret_value() == "redis://localhost:6379" + assert ( + settings.database_url.get_secret_value() + == "postgresql://trading:trading@localhost:5432/trading" + ) assert settings.log_level == "INFO" assert settings.risk_max_position_size == 0.1 assert settings.risk_stop_loss_pct == 5.0 @@ -25,7 +28,7 @@ def test_candle_creation(): """Test Candle model creation.""" from shared.models import Candle - now = datetime.now(timezone.utc) + now = datetime.now(UTC) candle = Candle( symbol="AAPL", timeframe="1m", @@ -47,7 +50,7 @@ def test_candle_creation(): def test_signal_creation(): """Test Signal model creation.""" - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal signal = Signal( strategy="rsi_strategy", @@ -69,9 +72,10 @@ def test_signal_creation(): def test_order_creation(): """Test Order model creation with defaults.""" - from shared.models import Order, OrderSide, OrderType, OrderStatus import uuid + from shared.models import Order, OrderSide, OrderStatus, OrderType + signal_id = str(uuid.uuid4()) order = Order( signal_id=signal_id, @@ -90,7 +94,7 @@ def test_order_creation(): def test_signal_conviction_default(): """Test Signal defaults for conviction, stop_loss, take_profit.""" - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal signal = Signal( strategy="rsi", @@ -107,7 +111,7 @@ def test_signal_conviction_default(): def test_signal_with_stops(): """Test Signal with explicit conviction, stop_loss, take_profit.""" - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal signal = Signal( strategy="rsi", diff --git a/shared/tests/test_news_events.py b/shared/tests/test_news_events.py new file mode 100644 index 0000000..f748d8a --- /dev/null +++ b/shared/tests/test_news_events.py @@ -0,0 +1,56 @@ +"""Tests for NewsEvent.""" + +from datetime import UTC, datetime + +from shared.events import Event, EventType, NewsEvent +from shared.models import NewsCategory, NewsItem + + +def test_news_event_to_dict(): + item = NewsItem( + source="finnhub", + headline="Test", + published_at=datetime(2026, 4, 2, tzinfo=UTC), + sentiment=0.5, + category=NewsCategory.MACRO, + ) + event = NewsEvent(data=item) + d = event.to_dict() + assert d["type"] == EventType.NEWS + assert d["data"]["source"] == "finnhub" + + +def test_news_event_from_raw(): + raw = { + "type": "NEWS", + "data": { + "id": "abc", + "source": "rss", + "headline": "Test headline", + "published_at": "2026-04-02T00:00:00+00:00", + "sentiment": 0.3, + "category": "earnings", + "symbols": ["AAPL"], + "raw_data": {}, + }, + } + event = NewsEvent.from_raw(raw) + assert event.data.source == "rss" + assert event.data.symbols == ["AAPL"] + + +def test_event_dispatcher_news(): + raw = { + "type": "NEWS", + "data": { + "id": "abc", + "source": "finnhub", + "headline": "Test", + "published_at": "2026-04-02T00:00:00+00:00", + "sentiment": 0.0, + "category": "macro", + "raw_data": {}, + }, + } + event = Event.from_dict(raw) + assert isinstance(event, NewsEvent) diff --git a/shared/tests/test_notifier.py b/shared/tests/test_notifier.py index 3d29830..cc98a56 100644 --- a/shared/tests/test_notifier.py +++ b/shared/tests/test_notifier.py @@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from shared.models import Signal, Order, OrderSide, OrderType, OrderStatus, Position +from shared.models import Order, OrderSide, OrderStatus, OrderType, Position, Signal from shared.notifier import TelegramNotifier @@ -86,7 +86,7 @@ class TestTelegramNotifierFormatters: notifier = TelegramNotifier(bot_token="fake-token", chat_id="123") signal = Signal( strategy="rsi_strategy", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, price=Decimal("50000.00"), quantity=Decimal("0.01"), @@ -99,7 +99,7 @@ class TestTelegramNotifierFormatters: msg = mock_send.call_args[0][0] assert "BUY" in msg assert "rsi_strategy" in msg - assert "BTCUSDT" in msg + assert "AAPL" in msg assert "50000.00" in msg assert "0.01" in msg assert "RSI oversold" in msg @@ -109,7 +109,7 @@ class TestTelegramNotifierFormatters: notifier = TelegramNotifier(bot_token="fake-token", chat_id="123") order = Order( signal_id=str(uuid.uuid4()), - symbol="ETHUSDT", + symbol="MSFT", side=OrderSide.SELL, type=OrderType.LIMIT, price=Decimal("3000.50"), @@ -122,7 +122,7 @@ class TestTelegramNotifierFormatters: mock_send.assert_called_once() msg = mock_send.call_args[0][0] assert "FILLED" in msg - assert "ETHUSDT" in msg + assert "MSFT" in msg assert "SELL" in msg assert "3000.50" in msg assert "1.5" in msg @@ -143,7 +143,7 @@ class TestTelegramNotifierFormatters: notifier = TelegramNotifier(bot_token="fake-token", chat_id="123") positions = [ Position( - symbol="BTCUSDT", + symbol="AAPL", quantity=Decimal("0.1"), avg_entry_price=Decimal("50000"), current_price=Decimal("51000"), @@ -158,7 +158,7 @@ class TestTelegramNotifierFormatters: ) mock_send.assert_called_once() msg = mock_send.call_args[0][0] - assert "BTCUSDT" in msg + assert "AAPL" in msg assert "5100.00" in msg assert "100.00" in msg diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py index e287777..e0781af 100644 --- a/shared/tests/test_resilience.py +++ b/shared/tests/test_resilience.py @@ -1,139 +1,176 @@ -"""Tests for retry with backoff and circuit breaker.""" +"""Tests for shared.resilience module.""" -import time +import asyncio import pytest -from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff +from shared.resilience import CircuitBreaker, async_timeout, retry_async +# --- retry_async tests --- -# --------------------------------------------------------------------------- -# retry_with_backoff tests -# --------------------------------------------------------------------------- - -@pytest.mark.asyncio -async def test_retry_succeeds_first_try(): +async def test_succeeds_without_retry(): + """Function succeeds first try, called once.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def succeed(): + @retry_async() + async def fn(): nonlocal call_count call_count += 1 return "ok" - result = await succeed() + result = await fn() assert result == "ok" assert call_count == 1 -@pytest.mark.asyncio -async def test_retry_succeeds_after_failures(): +async def test_retries_on_failure_then_succeeds(): + """Fails twice then succeeds, verify call count.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def flaky(): + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): nonlocal call_count call_count += 1 if call_count < 3: - raise ValueError("not yet") + raise RuntimeError("transient") return "recovered" - result = await flaky() + result = await fn() assert result == "recovered" assert call_count == 3 -@pytest.mark.asyncio -async def test_retry_raises_after_max_retries(): +async def test_raises_after_max_retries(): + """Always fails, raises after max retries.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def always_fail(): + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): nonlocal call_count call_count += 1 - raise RuntimeError("permanent") + raise ValueError("permanent") - with pytest.raises(RuntimeError, match="permanent"): - await always_fail() - # 1 initial + 3 retries = 4 calls + with pytest.raises(ValueError, match="permanent"): + await fn() + + # 1 initial + 3 retries = 4 total calls assert call_count == 4 -@pytest.mark.asyncio -async def test_retry_respects_max_delay(): - """Backoff should be capped at max_delay.""" +async def test_no_retry_on_excluded_exception(): + """Excluded exception raises immediately, call count = 1.""" + call_count = 0 - @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02) - async def always_fail(): - raise RuntimeError("fail") + @retry_async(max_retries=3, base_delay=0.01, exclude=(TypeError,)) + async def fn(): + nonlocal call_count + call_count += 1 + raise TypeError("excluded") - start = time.monotonic() - with pytest.raises(RuntimeError): - await always_fail() - elapsed = time.monotonic() - start - # With max_delay=0.02 and 2 retries, total delay should be small - assert elapsed < 0.5 + with pytest.raises(TypeError, match="excluded"): + await fn() + + assert call_count == 1 -# --------------------------------------------------------------------------- -# CircuitBreaker tests -# --------------------------------------------------------------------------- +# --- CircuitBreaker tests --- -def test_circuit_starts_closed(): - cb = CircuitBreaker(failure_threshold=3, recovery_timeout=0.05) - assert cb.state == CircuitState.CLOSED - assert cb.allow_request() is True +async def test_closed_allows_calls(): + """CircuitBreaker in closed state passes through.""" + cb = CircuitBreaker(failure_threshold=5, cooldown=60.0) + async def fn(): + return "ok" + + result = await cb.call(fn) + assert result == "ok" + + +async def test_opens_after_threshold(): + """After N failures, raises RuntimeError.""" + cb = CircuitBreaker(failure_threshold=3, cooldown=60.0) + + async def fail(): + raise RuntimeError("fail") -def test_circuit_opens_after_threshold(): - cb = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0) for _ in range(3): - cb.record_failure() - assert cb.state == CircuitState.OPEN - assert cb.allow_request() is False + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + # Now the breaker should be open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + +async def test_half_open_after_cooldown(): + """After cooldown, allows recovery attempt.""" + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def fail(): + raise RuntimeError("fail") + + # Trip the breaker + for _ in range(2): + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + + # Breaker is open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + # Wait for cooldown + await asyncio.sleep(0.06) + + # Now should allow a call (half_open). Succeed to close it. + async def succeed(): + return "recovered" + + result = await cb.call(succeed) + assert result == "recovered" + + # Breaker should be closed again + result = await cb.call(succeed) + assert result == "recovered" + + +async def test_half_open_reopens_on_failure(): + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def always_fail(): + raise ConnectionError("fail") -def test_circuit_rejects_when_open(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN - assert cb.allow_request() is False + # Trip the breaker + for _ in range(2): + with pytest.raises(ConnectionError): + await cb.call(always_fail) + # Wait for cooldown + await asyncio.sleep(0.1) -def test_circuit_half_open_after_timeout(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN + # Half-open probe should fail and re-open + with pytest.raises(ConnectionError): + await cb.call(always_fail) - time.sleep(0.06) - assert cb.allow_request() is True - assert cb.state == CircuitState.HALF_OPEN + # Should be open again (no cooldown wait) + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(always_fail) -def test_circuit_closes_on_success(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN +# --- async_timeout tests --- - time.sleep(0.06) - cb.allow_request() # triggers HALF_OPEN - assert cb.state == CircuitState.HALF_OPEN - cb.record_success() - assert cb.state == CircuitState.CLOSED - assert cb.allow_request() is True +async def test_completes_within_timeout(): + """async_timeout doesn't interfere with fast operations.""" + async with async_timeout(1.0): + await asyncio.sleep(0.01) + result = 42 + assert result == 42 -def test_circuit_reopens_on_failure_in_half_open(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - time.sleep(0.06) - cb.allow_request() # HALF_OPEN - cb.record_failure() - assert cb.state == CircuitState.OPEN +async def test_raises_on_timeout(): + """async_timeout raises TimeoutError for slow operations.""" + with pytest.raises(TimeoutError): + async with async_timeout(0.05): + await asyncio.sleep(1.0) diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py index 67c3c82..c9311dd 100644 --- a/shared/tests/test_sa_models.py +++ b/shared/tests/test_sa_models.py @@ -14,6 +14,10 @@ def test_base_metadata_has_all_tables(): "trades", "positions", "portfolio_snapshots", + "news_items", + "symbol_scores", + "market_sentiment", + "stock_selections", } assert expected == table_names @@ -68,6 +72,9 @@ class TestSignalRow: "price", "quantity", "reason", + "conviction", + "stop_loss", + "take_profit", "created_at", } assert expected == cols @@ -120,44 +127,6 @@ class TestOrderRow: assert fk_cols == {"signal_id": "signals.id"} -class TestTradeRow: - def test_table_name(self): - from shared.sa_models import TradeRow - - assert TradeRow.__tablename__ == "trades" - - def test_columns(self): - from shared.sa_models import TradeRow - - mapper = inspect(TradeRow) - cols = {c.key for c in mapper.column_attrs} - expected = { - "id", - "order_id", - "symbol", - "side", - "price", - "quantity", - "fee", - "traded_at", - } - assert expected == cols - - def test_primary_key(self): - from shared.sa_models import TradeRow - - mapper = inspect(TradeRow) - pk_cols = [c.name for c in mapper.mapper.primary_key] - assert pk_cols == ["id"] - - def test_order_id_foreign_key(self): - from shared.sa_models import TradeRow - - table = TradeRow.__table__ - fk_cols = {fk.parent.name: fk.target_fullname for fk in table.foreign_keys} - assert fk_cols == {"order_id": "orders.id"} - - class TestPositionRow: def test_table_name(self): from shared.sa_models import PositionRow @@ -229,11 +198,3 @@ class TestStatusDefault: status_col = table.c.status assert status_col.server_default is not None assert status_col.server_default.arg == "PENDING" - - def test_trade_fee_server_default(self): - from shared.sa_models import TradeRow - - table = TradeRow.__table__ - fee_col = table.c.fee - assert fee_col.server_default is not None - assert fee_col.server_default.arg == "0" diff --git a/shared/tests/test_sa_news_models.py b/shared/tests/test_sa_news_models.py new file mode 100644 index 0000000..dc2d026 --- /dev/null +++ b/shared/tests/test_sa_news_models.py @@ -0,0 +1,29 @@ +"""Tests for news-related SQLAlchemy models.""" + +from shared.sa_models import MarketSentimentRow, NewsItemRow, StockSelectionRow, SymbolScoreRow + + +def test_news_item_row_tablename(): + assert NewsItemRow.__tablename__ == "news_items" + + +def test_symbol_score_row_tablename(): + assert SymbolScoreRow.__tablename__ == "symbol_scores" + + +def test_market_sentiment_row_tablename(): + assert MarketSentimentRow.__tablename__ == "market_sentiment" + + +def test_stock_selection_row_tablename(): + assert StockSelectionRow.__tablename__ == "stock_selections" + + +def test_news_item_row_columns(): + cols = {c.name for c in NewsItemRow.__table__.columns} + assert cols >= {"id", "source", "headline", "published_at", "sentiment", "category"} + + +def test_symbol_score_row_columns(): + cols = {c.name for c in SymbolScoreRow.__table__.columns} + assert cols >= {"id", "symbol", "news_score", "composite", "updated_at"} diff --git a/shared/tests/test_sentiment.py b/shared/tests/test_sentiment.py deleted file mode 100644 index 9bd8ea3..0000000 --- a/shared/tests/test_sentiment.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Tests for market sentiment module.""" - -from shared.sentiment import SentimentData - - -def test_sentiment_should_buy_default_no_data(): - s = SentimentData() - assert s.should_buy is True - assert s.should_block is False - - -def test_sentiment_should_buy_low_fear_greed(): - s = SentimentData(fear_greed_value=15) - assert s.should_buy is True - - -def test_sentiment_should_not_buy_on_greed(): - s = SentimentData(fear_greed_value=75) - assert s.should_buy is False - - -def test_sentiment_should_not_buy_negative_news(): - s = SentimentData(news_sentiment=-0.4) - assert s.should_buy is False - - -def test_sentiment_should_buy_positive_news(): - s = SentimentData(fear_greed_value=50, news_sentiment=0.3) - assert s.should_buy is True - - -def test_sentiment_should_block_extreme_greed(): - s = SentimentData(fear_greed_value=85) - assert s.should_block is True - - -def test_sentiment_should_block_very_negative_news(): - s = SentimentData(news_sentiment=-0.6) - assert s.should_block is True - - -def test_sentiment_no_block_on_neutral(): - s = SentimentData(fear_greed_value=50, news_sentiment=0.0) - assert s.should_block is False diff --git a/shared/tests/test_sentiment_aggregator.py b/shared/tests/test_sentiment_aggregator.py new file mode 100644 index 0000000..9193785 --- /dev/null +++ b/shared/tests/test_sentiment_aggregator.py @@ -0,0 +1,79 @@ +"""Tests for sentiment aggregator.""" + +from datetime import UTC, datetime, timedelta + +import pytest + +from shared.sentiment import SentimentAggregator + + +@pytest.fixture +def aggregator(): + return SentimentAggregator() + + +def test_freshness_decay_recent(): + a = SentimentAggregator() + now = datetime.now(UTC) + assert a._freshness_decay(now, now) == 1.0 + + +def test_freshness_decay_3_hours(): + a = SentimentAggregator() + now = datetime.now(UTC) + assert a._freshness_decay(now - timedelta(hours=3), now) == 0.7 + + +def test_freshness_decay_12_hours(): + a = SentimentAggregator() + now = datetime.now(UTC) + assert a._freshness_decay(now - timedelta(hours=12), now) == 0.3 + + +def test_freshness_decay_old(): + a = SentimentAggregator() + now = datetime.now(UTC) + assert a._freshness_decay(now - timedelta(days=2), now) == 0.0 + + +def test_compute_composite(): + a = SentimentAggregator() + composite = a._compute_composite( + news_score=0.5, social_score=0.3, policy_score=0.8, filing_score=0.2 + ) + expected = 0.5 * 0.3 + 0.3 * 0.2 + 0.8 * 0.3 + 0.2 * 0.2 + assert abs(composite - expected) < 0.001 + + +def test_aggregate_news_by_symbol(aggregator): + now = datetime.now(UTC) + news_items = [ + {"symbols": ["AAPL"], "sentiment": 0.8, "category": "earnings", "published_at": now}, + { + "symbols": ["AAPL"], + "sentiment": 0.3, + "category": "macro", + "published_at": now - timedelta(hours=2), + }, + {"symbols": ["MSFT"], "sentiment": -0.5, "category": "policy", "published_at": now}, + ] + scores = aggregator.aggregate(news_items, now) + assert "AAPL" in scores + assert "MSFT" in scores + assert scores["AAPL"].news_count == 2 + assert scores["AAPL"].news_score > 0 + assert scores["MSFT"].policy_score < 0 + + +def test_aggregate_empty(aggregator): + now = datetime.now(UTC) + assert aggregator.aggregate([], now) == {} + + +def test_determine_regime(): + a = SentimentAggregator() + assert a.determine_regime(15, None) == "risk_off" + assert a.determine_regime(15, 35.0) == "risk_off" + assert a.determine_regime(50, 35.0) == "risk_off" + assert a.determine_regime(70, 15.0) == "risk_on" + assert a.determine_regime(50, 20.0) == "neutral" diff --git a/shared/tests/test_sentiment_models.py b/shared/tests/test_sentiment_models.py new file mode 100644 index 0000000..e00ffa6 --- /dev/null +++ b/shared/tests/test_sentiment_models.py @@ -0,0 +1,113 @@ +"""Tests for news and sentiment models.""" + +from datetime import UTC, datetime + +from shared.models import NewsCategory, NewsItem, OrderSide +from shared.sentiment_models import Candidate, MarketSentiment, SelectedStock, SymbolScore + + +def test_news_item_defaults(): + item = NewsItem( + source="finnhub", + headline="Test headline", + published_at=datetime(2026, 4, 2, tzinfo=UTC), + sentiment=0.5, + category=NewsCategory.MACRO, + ) + assert item.id + assert item.symbols == [] + assert item.summary is None + assert item.raw_data == {} + assert item.created_at is not None + + +def test_news_item_with_symbols(): + item = NewsItem( + source="rss", + headline="AAPL earnings beat", + published_at=datetime(2026, 4, 2, tzinfo=UTC), + sentiment=0.8, + category=NewsCategory.EARNINGS, + symbols=["AAPL"], + ) + assert item.symbols == ["AAPL"] + assert item.category == NewsCategory.EARNINGS + + +def test_news_category_values(): + assert NewsCategory.POLICY == "policy" + assert NewsCategory.EARNINGS == "earnings" + assert NewsCategory.MACRO == "macro" + assert NewsCategory.SOCIAL == "social" + assert NewsCategory.FILING == "filing" + assert NewsCategory.FED == "fed" + + +def test_symbol_score(): + score = SymbolScore( + symbol="AAPL", + news_score=0.5, + news_count=10, + social_score=0.3, + policy_score=0.0, + filing_score=0.2, + composite=0.3, + updated_at=datetime(2026, 4, 2, tzinfo=UTC), + ) + assert score.symbol == "AAPL" + assert score.composite == 0.3 + + +def test_market_sentiment(): + ms = MarketSentiment( + fear_greed=25, + fear_greed_label="Extreme Fear", + vix=32.5, + fed_stance="hawkish", + market_regime="risk_off", + updated_at=datetime(2026, 4, 2, tzinfo=UTC), + ) + assert ms.market_regime == "risk_off" + assert ms.vix == 32.5 + + +def test_market_sentiment_no_vix(): + ms = MarketSentiment( + fear_greed=50, + fear_greed_label="Neutral", + fed_stance="neutral", + market_regime="neutral", + updated_at=datetime(2026, 4, 2, tzinfo=UTC), + ) + assert ms.vix is None + + +def test_selected_stock(): + ss = SelectedStock( + symbol="NVDA", + side=OrderSide.BUY, + conviction=0.85, + reason="CHIPS Act expansion", + key_news=["Trump signs CHIPS Act expansion"], + ) + assert ss.conviction == 0.85 + assert len(ss.key_news) == 1 + + +def test_candidate(): + c = Candidate( + symbol="TSLA", + source="sentiment", + direction=OrderSide.BUY, + score=0.75, + reason="High social buzz", + ) + assert c.direction == OrderSide.BUY + + c2 = Candidate( + symbol="XOM", + source="llm", + score=0.6, + reason="Oil price surge", + ) + assert c2.direction is None |
