diff options
Diffstat (limited to 'shared/tests')
| -rw-r--r-- | shared/tests/test_alpaca.py | 4 | ||||
| -rw-r--r-- | shared/tests/test_broker.py | 5 | ||||
| -rw-r--r-- | shared/tests/test_config_validation.py | 29 | ||||
| -rw-r--r-- | shared/tests/test_db.py | 69 | ||||
| -rw-r--r-- | shared/tests/test_db_news.py | 13 | ||||
| -rw-r--r-- | shared/tests/test_events.py | 10 | ||||
| -rw-r--r-- | shared/tests/test_models.py | 20 | ||||
| -rw-r--r-- | shared/tests/test_news_events.py | 6 | ||||
| -rw-r--r-- | shared/tests/test_notifier.py | 2 | ||||
| -rw-r--r-- | shared/tests/test_resilience.py | 203 | ||||
| -rw-r--r-- | shared/tests/test_sa_models.py | 49 | ||||
| -rw-r--r-- | shared/tests/test_sa_news_models.py | 2 | ||||
| -rw-r--r-- | shared/tests/test_sentiment.py | 44 | ||||
| -rw-r--r-- | shared/tests/test_sentiment_aggregator.py | 16 | ||||
| -rw-r--r-- | shared/tests/test_sentiment_models.py | 14 |
15 files changed, 262 insertions, 224 deletions
diff --git a/shared/tests/test_alpaca.py b/shared/tests/test_alpaca.py index 080b7c4..55a2b24 100644 --- a/shared/tests/test_alpaca.py +++ b/shared/tests/test_alpaca.py @@ -1,7 +1,9 @@ """Tests for Alpaca API client.""" -import pytest from unittest.mock import AsyncMock, MagicMock + +import pytest + from shared.alpaca import AlpacaClient diff --git a/shared/tests/test_broker.py b/shared/tests/test_broker.py index eb1582d..5636611 100644 --- a/shared/tests/test_broker.py +++ b/shared/tests/test_broker.py @@ -1,10 +1,11 @@ """Tests for the Redis broker.""" -import pytest import json -import redis from unittest.mock import AsyncMock, patch +import pytest +import redis + @pytest.mark.asyncio async def test_broker_publish(): diff --git a/shared/tests/test_config_validation.py b/shared/tests/test_config_validation.py new file mode 100644 index 0000000..9376dc6 --- /dev/null +++ b/shared/tests/test_config_validation.py @@ -0,0 +1,29 @@ +"""Tests for config validation.""" + +import pytest +from pydantic import ValidationError + +from shared.config import Settings + + +class TestConfigValidation: + def test_valid_defaults(self): + settings = Settings() + assert settings.risk_max_position_size == 0.1 + + def test_invalid_position_size(self): + with pytest.raises(ValidationError, match="risk_max_position_size"): + Settings(risk_max_position_size=-0.1) + + def test_invalid_health_port(self): + with pytest.raises(ValidationError, match="health_port"): + Settings(health_port=80) + + def test_invalid_log_level(self): + with pytest.raises(ValidationError, match="log_level"): + Settings(log_level="INVALID") + + def test_secret_fields_masked(self): + settings = Settings(alpaca_api_key="my-secret-key") + assert "my-secret-key" not in repr(settings) + assert settings.alpaca_api_key.get_secret_value() == "my-secret-key" diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index 239ee64..b44a713 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -1,10 +1,11 @@ """Tests for the SQLAlchemy async database layer.""" -import pytest +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch +import pytest + def make_candle(): from shared.models import Candle @@ -12,7 +13,7 @@ def make_candle(): return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("51000"), low=Decimal("49500"), @@ -22,7 +23,7 @@ def make_candle(): def make_signal(): - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal return Signal( id="sig-1", @@ -32,12 +33,12 @@ def make_signal(): price=Decimal("50000"), quantity=Decimal("0.1"), reason="Golden cross", - created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + created_at=datetime(2024, 1, 1, tzinfo=UTC), ) def make_order(): - from shared.models import Order, OrderSide, OrderType, OrderStatus + from shared.models import Order, OrderSide, OrderStatus, OrderType return Order( id="ord-1", @@ -48,7 +49,7 @@ def make_order(): price=Decimal("50000"), quantity=Decimal("0.1"), status=OrderStatus.PENDING, - created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + created_at=datetime(2024, 1, 1, tzinfo=UTC), ) @@ -101,6 +102,54 @@ class TestDatabaseConnect: mock_create.assert_called_once() @pytest.mark.asyncio + async def test_connect_passes_pool_params_for_postgres(self): + from shared.db import Database + + db = Database("postgresql+asyncpg://host/db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect(pool_size=5, max_overflow=3, pool_recycle=1800) + mock_create.assert_called_once_with( + "postgresql+asyncpg://host/db", + pool_pre_ping=True, + pool_size=5, + max_overflow=3, + pool_recycle=1800, + ) + + @pytest.mark.asyncio + async def test_connect_skips_pool_params_for_sqlite(self): + from shared.db import Database + + db = Database("sqlite+aiosqlite:///test.db") + + mock_conn = AsyncMock() + mock_cm = AsyncMock() + mock_cm.__aenter__.return_value = mock_conn + + mock_engine = MagicMock() + mock_engine.begin.return_value = mock_cm + mock_engine.dispose = AsyncMock() + + with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create: + with patch("shared.db.async_sessionmaker"): + with patch("shared.db.Base") as mock_base: + mock_base.metadata.create_all = MagicMock() + await db.connect() + mock_create.assert_called_once_with("sqlite+aiosqlite:///test.db") + + @pytest.mark.asyncio async def test_init_tables_is_alias_for_connect(self): from shared.db import Database @@ -211,7 +260,7 @@ class TestUpdateOrderStatus: db._session_factory = MagicMock(return_value=mock_session) - filled = datetime(2024, 1, 2, tzinfo=timezone.utc) + filled = datetime(2024, 1, 2, tzinfo=UTC) await db.update_order_status("ord-1", OrderStatus.FILLED, filled) mock_session.execute.assert_awaited_once() @@ -230,7 +279,7 @@ class TestGetCandles: mock_row._mapping = { "symbol": "AAPL", "timeframe": "1m", - "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc), + "open_time": datetime(2024, 1, 1, tzinfo=UTC), "open": Decimal("50000"), "high": Decimal("51000"), "low": Decimal("49500"), @@ -396,7 +445,7 @@ class TestGetPortfolioSnapshots: mock_row.total_value = Decimal("10000") mock_row.realized_pnl = Decimal("0") mock_row.unrealized_pnl = Decimal("500") - mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=timezone.utc) + mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=UTC) mock_result = MagicMock() mock_result.scalars.return_value.all.return_value = [mock_row] diff --git a/shared/tests/test_db_news.py b/shared/tests/test_db_news.py index a2c9140..c184bed 100644 --- a/shared/tests/test_db_news.py +++ b/shared/tests/test_db_news.py @@ -1,11 +1,12 @@ """Tests for database news/sentiment methods. Uses in-memory SQLite.""" +from datetime import UTC, date, datetime + import pytest -from datetime import datetime, date, timezone from shared.db import Database -from shared.models import NewsItem, NewsCategory -from shared.sentiment_models import SymbolScore, MarketSentiment +from shared.models import NewsCategory, NewsItem +from shared.sentiment_models import MarketSentiment, SymbolScore @pytest.fixture @@ -20,7 +21,7 @@ async def test_insert_and_get_news_items(db): item = NewsItem( source="finnhub", headline="AAPL earnings beat", - published_at=datetime(2026, 4, 2, 12, 0, tzinfo=timezone.utc), + published_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC), sentiment=0.8, category=NewsCategory.EARNINGS, symbols=["AAPL"], @@ -40,7 +41,7 @@ async def test_upsert_symbol_score(db): policy_score=0.0, filing_score=0.2, composite=0.3, - updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + updated_at=datetime(2026, 4, 2, tzinfo=UTC), ) await db.upsert_symbol_score(score) scores = await db.get_top_symbol_scores(limit=5) @@ -55,7 +56,7 @@ async def test_upsert_market_sentiment(db): vix=18.2, fed_stance="neutral", market_regime="neutral", - updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + updated_at=datetime(2026, 4, 2, tzinfo=UTC), ) await db.upsert_market_sentiment(ms) result = await db.get_latest_market_sentiment() diff --git a/shared/tests/test_events.py b/shared/tests/test_events.py index 6077d93..1ccd904 100644 --- a/shared/tests/test_events.py +++ b/shared/tests/test_events.py @@ -1,7 +1,7 @@ """Tests for shared event types.""" +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone def make_candle(): @@ -10,7 +10,7 @@ def make_candle(): return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("51000"), low=Decimal("49500"), @@ -20,7 +20,7 @@ def make_candle(): def make_signal(): - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal return Signal( strategy="test", @@ -59,7 +59,7 @@ def test_candle_event_deserialize(): def test_signal_event_serialize(): """Test SignalEvent serializes to dict correctly.""" - from shared.events import SignalEvent, EventType + from shared.events import EventType, SignalEvent signal = make_signal() event = SignalEvent(data=signal) @@ -71,7 +71,7 @@ def test_signal_event_serialize(): def test_event_from_dict_dispatch(): """Test Event.from_dict dispatches to correct class.""" - from shared.events import Event, CandleEvent, SignalEvent + from shared.events import CandleEvent, Event, SignalEvent candle = make_candle() event = CandleEvent(data=candle) diff --git a/shared/tests/test_models.py b/shared/tests/test_models.py index 04098ce..40bb791 100644 --- a/shared/tests/test_models.py +++ b/shared/tests/test_models.py @@ -1,8 +1,8 @@ """Tests for shared models and settings.""" import os +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone from unittest.mock import patch @@ -12,8 +12,11 @@ def test_settings_defaults(): with patch.dict(os.environ, {}, clear=False): settings = Settings() - assert settings.redis_url == "redis://localhost:6379" - assert settings.database_url == "postgresql://trading:trading@localhost:5432/trading" + assert settings.redis_url.get_secret_value() == "redis://localhost:6379" + assert ( + settings.database_url.get_secret_value() + == "postgresql://trading:trading@localhost:5432/trading" + ) assert settings.log_level == "INFO" assert settings.risk_max_position_size == 0.1 assert settings.risk_stop_loss_pct == 5.0 @@ -25,7 +28,7 @@ def test_candle_creation(): """Test Candle model creation.""" from shared.models import Candle - now = datetime.now(timezone.utc) + now = datetime.now(UTC) candle = Candle( symbol="AAPL", timeframe="1m", @@ -47,7 +50,7 @@ def test_candle_creation(): def test_signal_creation(): """Test Signal model creation.""" - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal signal = Signal( strategy="rsi_strategy", @@ -69,9 +72,10 @@ def test_signal_creation(): def test_order_creation(): """Test Order model creation with defaults.""" - from shared.models import Order, OrderSide, OrderType, OrderStatus import uuid + from shared.models import Order, OrderSide, OrderStatus, OrderType + signal_id = str(uuid.uuid4()) order = Order( signal_id=signal_id, @@ -90,7 +94,7 @@ def test_order_creation(): def test_signal_conviction_default(): """Test Signal defaults for conviction, stop_loss, take_profit.""" - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal signal = Signal( strategy="rsi", @@ -107,7 +111,7 @@ def test_signal_conviction_default(): def test_signal_with_stops(): """Test Signal with explicit conviction, stop_loss, take_profit.""" - from shared.models import Signal, OrderSide + from shared.models import OrderSide, Signal signal = Signal( strategy="rsi", diff --git a/shared/tests/test_news_events.py b/shared/tests/test_news_events.py index 384796a..f748d8a 100644 --- a/shared/tests/test_news_events.py +++ b/shared/tests/test_news_events.py @@ -1,16 +1,16 @@ """Tests for NewsEvent.""" -from datetime import datetime, timezone +from datetime import UTC, datetime +from shared.events import Event, EventType, NewsEvent from shared.models import NewsCategory, NewsItem -from shared.events import NewsEvent, EventType, Event def test_news_event_to_dict(): item = NewsItem( source="finnhub", headline="Test", - published_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + published_at=datetime(2026, 4, 2, tzinfo=UTC), sentiment=0.5, category=NewsCategory.MACRO, ) diff --git a/shared/tests/test_notifier.py b/shared/tests/test_notifier.py index 6c81369..cc98a56 100644 --- a/shared/tests/test_notifier.py +++ b/shared/tests/test_notifier.py @@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from shared.models import Signal, Order, OrderSide, OrderType, OrderStatus, Position +from shared.models import Order, OrderSide, OrderStatus, OrderType, Position, Signal from shared.notifier import TelegramNotifier diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py index e287777..e0781af 100644 --- a/shared/tests/test_resilience.py +++ b/shared/tests/test_resilience.py @@ -1,139 +1,176 @@ -"""Tests for retry with backoff and circuit breaker.""" +"""Tests for shared.resilience module.""" -import time +import asyncio import pytest -from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff +from shared.resilience import CircuitBreaker, async_timeout, retry_async +# --- retry_async tests --- -# --------------------------------------------------------------------------- -# retry_with_backoff tests -# --------------------------------------------------------------------------- - -@pytest.mark.asyncio -async def test_retry_succeeds_first_try(): +async def test_succeeds_without_retry(): + """Function succeeds first try, called once.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def succeed(): + @retry_async() + async def fn(): nonlocal call_count call_count += 1 return "ok" - result = await succeed() + result = await fn() assert result == "ok" assert call_count == 1 -@pytest.mark.asyncio -async def test_retry_succeeds_after_failures(): +async def test_retries_on_failure_then_succeeds(): + """Fails twice then succeeds, verify call count.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def flaky(): + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): nonlocal call_count call_count += 1 if call_count < 3: - raise ValueError("not yet") + raise RuntimeError("transient") return "recovered" - result = await flaky() + result = await fn() assert result == "recovered" assert call_count == 3 -@pytest.mark.asyncio -async def test_retry_raises_after_max_retries(): +async def test_raises_after_max_retries(): + """Always fails, raises after max retries.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def always_fail(): + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): nonlocal call_count call_count += 1 - raise RuntimeError("permanent") + raise ValueError("permanent") - with pytest.raises(RuntimeError, match="permanent"): - await always_fail() - # 1 initial + 3 retries = 4 calls + with pytest.raises(ValueError, match="permanent"): + await fn() + + # 1 initial + 3 retries = 4 total calls assert call_count == 4 -@pytest.mark.asyncio -async def test_retry_respects_max_delay(): - """Backoff should be capped at max_delay.""" +async def test_no_retry_on_excluded_exception(): + """Excluded exception raises immediately, call count = 1.""" + call_count = 0 - @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02) - async def always_fail(): - raise RuntimeError("fail") + @retry_async(max_retries=3, base_delay=0.01, exclude=(TypeError,)) + async def fn(): + nonlocal call_count + call_count += 1 + raise TypeError("excluded") - start = time.monotonic() - with pytest.raises(RuntimeError): - await always_fail() - elapsed = time.monotonic() - start - # With max_delay=0.02 and 2 retries, total delay should be small - assert elapsed < 0.5 + with pytest.raises(TypeError, match="excluded"): + await fn() + + assert call_count == 1 -# --------------------------------------------------------------------------- -# CircuitBreaker tests -# --------------------------------------------------------------------------- +# --- CircuitBreaker tests --- -def test_circuit_starts_closed(): - cb = CircuitBreaker(failure_threshold=3, recovery_timeout=0.05) - assert cb.state == CircuitState.CLOSED - assert cb.allow_request() is True +async def test_closed_allows_calls(): + """CircuitBreaker in closed state passes through.""" + cb = CircuitBreaker(failure_threshold=5, cooldown=60.0) + async def fn(): + return "ok" + + result = await cb.call(fn) + assert result == "ok" + + +async def test_opens_after_threshold(): + """After N failures, raises RuntimeError.""" + cb = CircuitBreaker(failure_threshold=3, cooldown=60.0) + + async def fail(): + raise RuntimeError("fail") -def test_circuit_opens_after_threshold(): - cb = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0) for _ in range(3): - cb.record_failure() - assert cb.state == CircuitState.OPEN - assert cb.allow_request() is False + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + # Now the breaker should be open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + +async def test_half_open_after_cooldown(): + """After cooldown, allows recovery attempt.""" + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def fail(): + raise RuntimeError("fail") + + # Trip the breaker + for _ in range(2): + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + + # Breaker is open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + # Wait for cooldown + await asyncio.sleep(0.06) + + # Now should allow a call (half_open). Succeed to close it. + async def succeed(): + return "recovered" + + result = await cb.call(succeed) + assert result == "recovered" + + # Breaker should be closed again + result = await cb.call(succeed) + assert result == "recovered" + + +async def test_half_open_reopens_on_failure(): + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def always_fail(): + raise ConnectionError("fail") -def test_circuit_rejects_when_open(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN - assert cb.allow_request() is False + # Trip the breaker + for _ in range(2): + with pytest.raises(ConnectionError): + await cb.call(always_fail) + # Wait for cooldown + await asyncio.sleep(0.1) -def test_circuit_half_open_after_timeout(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN + # Half-open probe should fail and re-open + with pytest.raises(ConnectionError): + await cb.call(always_fail) - time.sleep(0.06) - assert cb.allow_request() is True - assert cb.state == CircuitState.HALF_OPEN + # Should be open again (no cooldown wait) + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(always_fail) -def test_circuit_closes_on_success(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN +# --- async_timeout tests --- - time.sleep(0.06) - cb.allow_request() # triggers HALF_OPEN - assert cb.state == CircuitState.HALF_OPEN - cb.record_success() - assert cb.state == CircuitState.CLOSED - assert cb.allow_request() is True +async def test_completes_within_timeout(): + """async_timeout doesn't interfere with fast operations.""" + async with async_timeout(1.0): + await asyncio.sleep(0.01) + result = 42 + assert result == 42 -def test_circuit_reopens_on_failure_in_half_open(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - time.sleep(0.06) - cb.allow_request() # HALF_OPEN - cb.record_failure() - assert cb.state == CircuitState.OPEN +async def test_raises_on_timeout(): + """async_timeout raises TimeoutError for slow operations.""" + with pytest.raises(TimeoutError): + async with async_timeout(0.05): + await asyncio.sleep(1.0) diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py index dc6355e..c9311dd 100644 --- a/shared/tests/test_sa_models.py +++ b/shared/tests/test_sa_models.py @@ -72,6 +72,9 @@ class TestSignalRow: "price", "quantity", "reason", + "conviction", + "stop_loss", + "take_profit", "created_at", } assert expected == cols @@ -124,44 +127,6 @@ class TestOrderRow: assert fk_cols == {"signal_id": "signals.id"} -class TestTradeRow: - def test_table_name(self): - from shared.sa_models import TradeRow - - assert TradeRow.__tablename__ == "trades" - - def test_columns(self): - from shared.sa_models import TradeRow - - mapper = inspect(TradeRow) - cols = {c.key for c in mapper.column_attrs} - expected = { - "id", - "order_id", - "symbol", - "side", - "price", - "quantity", - "fee", - "traded_at", - } - assert expected == cols - - def test_primary_key(self): - from shared.sa_models import TradeRow - - mapper = inspect(TradeRow) - pk_cols = [c.name for c in mapper.mapper.primary_key] - assert pk_cols == ["id"] - - def test_order_id_foreign_key(self): - from shared.sa_models import TradeRow - - table = TradeRow.__table__ - fk_cols = {fk.parent.name: fk.target_fullname for fk in table.foreign_keys} - assert fk_cols == {"order_id": "orders.id"} - - class TestPositionRow: def test_table_name(self): from shared.sa_models import PositionRow @@ -233,11 +198,3 @@ class TestStatusDefault: status_col = table.c.status assert status_col.server_default is not None assert status_col.server_default.arg == "PENDING" - - def test_trade_fee_server_default(self): - from shared.sa_models import TradeRow - - table = TradeRow.__table__ - fee_col = table.c.fee - assert fee_col.server_default is not None - assert fee_col.server_default.arg == "0" diff --git a/shared/tests/test_sa_news_models.py b/shared/tests/test_sa_news_models.py index 91e6d4a..dc2d026 100644 --- a/shared/tests/test_sa_news_models.py +++ b/shared/tests/test_sa_news_models.py @@ -1,6 +1,6 @@ """Tests for news-related SQLAlchemy models.""" -from shared.sa_models import NewsItemRow, SymbolScoreRow, MarketSentimentRow, StockSelectionRow +from shared.sa_models import MarketSentimentRow, NewsItemRow, StockSelectionRow, SymbolScoreRow def test_news_item_row_tablename(): diff --git a/shared/tests/test_sentiment.py b/shared/tests/test_sentiment.py deleted file mode 100644 index 9bd8ea3..0000000 --- a/shared/tests/test_sentiment.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Tests for market sentiment module.""" - -from shared.sentiment import SentimentData - - -def test_sentiment_should_buy_default_no_data(): - s = SentimentData() - assert s.should_buy is True - assert s.should_block is False - - -def test_sentiment_should_buy_low_fear_greed(): - s = SentimentData(fear_greed_value=15) - assert s.should_buy is True - - -def test_sentiment_should_not_buy_on_greed(): - s = SentimentData(fear_greed_value=75) - assert s.should_buy is False - - -def test_sentiment_should_not_buy_negative_news(): - s = SentimentData(news_sentiment=-0.4) - assert s.should_buy is False - - -def test_sentiment_should_buy_positive_news(): - s = SentimentData(fear_greed_value=50, news_sentiment=0.3) - assert s.should_buy is True - - -def test_sentiment_should_block_extreme_greed(): - s = SentimentData(fear_greed_value=85) - assert s.should_block is True - - -def test_sentiment_should_block_very_negative_news(): - s = SentimentData(news_sentiment=-0.6) - assert s.should_block is True - - -def test_sentiment_no_block_on_neutral(): - s = SentimentData(fear_greed_value=50, news_sentiment=0.0) - assert s.should_block is False diff --git a/shared/tests/test_sentiment_aggregator.py b/shared/tests/test_sentiment_aggregator.py index a99c711..9193785 100644 --- a/shared/tests/test_sentiment_aggregator.py +++ b/shared/tests/test_sentiment_aggregator.py @@ -1,7 +1,9 @@ """Tests for sentiment aggregator.""" +from datetime import UTC, datetime, timedelta + import pytest -from datetime import datetime, timezone, timedelta + from shared.sentiment import SentimentAggregator @@ -12,25 +14,25 @@ def aggregator(): def test_freshness_decay_recent(): a = SentimentAggregator() - now = datetime.now(timezone.utc) + now = datetime.now(UTC) assert a._freshness_decay(now, now) == 1.0 def test_freshness_decay_3_hours(): a = SentimentAggregator() - now = datetime.now(timezone.utc) + now = datetime.now(UTC) assert a._freshness_decay(now - timedelta(hours=3), now) == 0.7 def test_freshness_decay_12_hours(): a = SentimentAggregator() - now = datetime.now(timezone.utc) + now = datetime.now(UTC) assert a._freshness_decay(now - timedelta(hours=12), now) == 0.3 def test_freshness_decay_old(): a = SentimentAggregator() - now = datetime.now(timezone.utc) + now = datetime.now(UTC) assert a._freshness_decay(now - timedelta(days=2), now) == 0.0 @@ -44,7 +46,7 @@ def test_compute_composite(): def test_aggregate_news_by_symbol(aggregator): - now = datetime.now(timezone.utc) + now = datetime.now(UTC) news_items = [ {"symbols": ["AAPL"], "sentiment": 0.8, "category": "earnings", "published_at": now}, { @@ -64,7 +66,7 @@ def test_aggregate_news_by_symbol(aggregator): def test_aggregate_empty(aggregator): - now = datetime.now(timezone.utc) + now = datetime.now(UTC) assert aggregator.aggregate([], now) == {} diff --git a/shared/tests/test_sentiment_models.py b/shared/tests/test_sentiment_models.py index 25fc371..e00ffa6 100644 --- a/shared/tests/test_sentiment_models.py +++ b/shared/tests/test_sentiment_models.py @@ -1,16 +1,16 @@ """Tests for news and sentiment models.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from shared.models import NewsCategory, NewsItem, OrderSide -from shared.sentiment_models import SymbolScore, MarketSentiment, SelectedStock, Candidate +from shared.sentiment_models import Candidate, MarketSentiment, SelectedStock, SymbolScore def test_news_item_defaults(): item = NewsItem( source="finnhub", headline="Test headline", - published_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + published_at=datetime(2026, 4, 2, tzinfo=UTC), sentiment=0.5, category=NewsCategory.MACRO, ) @@ -25,7 +25,7 @@ def test_news_item_with_symbols(): item = NewsItem( source="rss", headline="AAPL earnings beat", - published_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + published_at=datetime(2026, 4, 2, tzinfo=UTC), sentiment=0.8, category=NewsCategory.EARNINGS, symbols=["AAPL"], @@ -52,7 +52,7 @@ def test_symbol_score(): policy_score=0.0, filing_score=0.2, composite=0.3, - updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + updated_at=datetime(2026, 4, 2, tzinfo=UTC), ) assert score.symbol == "AAPL" assert score.composite == 0.3 @@ -65,7 +65,7 @@ def test_market_sentiment(): vix=32.5, fed_stance="hawkish", market_regime="risk_off", - updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + updated_at=datetime(2026, 4, 2, tzinfo=UTC), ) assert ms.market_regime == "risk_off" assert ms.vix == 32.5 @@ -77,7 +77,7 @@ def test_market_sentiment_no_vix(): fear_greed_label="Neutral", fed_stance="neutral", market_regime="neutral", - updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc), + updated_at=datetime(2026, 4, 2, tzinfo=UTC), ) assert ms.vix is None |
