summaryrefslogtreecommitdiff
path: root/shared/tests
diff options
context:
space:
mode:
Diffstat (limited to 'shared/tests')
-rw-r--r--shared/tests/test_alpaca.py4
-rw-r--r--shared/tests/test_broker.py5
-rw-r--r--shared/tests/test_config_validation.py29
-rw-r--r--shared/tests/test_db.py69
-rw-r--r--shared/tests/test_db_news.py13
-rw-r--r--shared/tests/test_events.py10
-rw-r--r--shared/tests/test_models.py20
-rw-r--r--shared/tests/test_news_events.py6
-rw-r--r--shared/tests/test_notifier.py2
-rw-r--r--shared/tests/test_resilience.py203
-rw-r--r--shared/tests/test_sa_models.py49
-rw-r--r--shared/tests/test_sa_news_models.py2
-rw-r--r--shared/tests/test_sentiment.py44
-rw-r--r--shared/tests/test_sentiment_aggregator.py16
-rw-r--r--shared/tests/test_sentiment_models.py14
15 files changed, 262 insertions, 224 deletions
diff --git a/shared/tests/test_alpaca.py b/shared/tests/test_alpaca.py
index 080b7c4..55a2b24 100644
--- a/shared/tests/test_alpaca.py
+++ b/shared/tests/test_alpaca.py
@@ -1,7 +1,9 @@
"""Tests for Alpaca API client."""
-import pytest
from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
from shared.alpaca import AlpacaClient
diff --git a/shared/tests/test_broker.py b/shared/tests/test_broker.py
index eb1582d..5636611 100644
--- a/shared/tests/test_broker.py
+++ b/shared/tests/test_broker.py
@@ -1,10 +1,11 @@
"""Tests for the Redis broker."""
-import pytest
import json
-import redis
from unittest.mock import AsyncMock, patch
+import pytest
+import redis
+
@pytest.mark.asyncio
async def test_broker_publish():
diff --git a/shared/tests/test_config_validation.py b/shared/tests/test_config_validation.py
new file mode 100644
index 0000000..9376dc6
--- /dev/null
+++ b/shared/tests/test_config_validation.py
@@ -0,0 +1,29 @@
+"""Tests for config validation."""
+
+import pytest
+from pydantic import ValidationError
+
+from shared.config import Settings
+
+
+class TestConfigValidation:
+ def test_valid_defaults(self):
+ settings = Settings()
+ assert settings.risk_max_position_size == 0.1
+
+ def test_invalid_position_size(self):
+ with pytest.raises(ValidationError, match="risk_max_position_size"):
+ Settings(risk_max_position_size=-0.1)
+
+ def test_invalid_health_port(self):
+ with pytest.raises(ValidationError, match="health_port"):
+ Settings(health_port=80)
+
+ def test_invalid_log_level(self):
+ with pytest.raises(ValidationError, match="log_level"):
+ Settings(log_level="INVALID")
+
+ def test_secret_fields_masked(self):
+ settings = Settings(alpaca_api_key="my-secret-key")
+ assert "my-secret-key" not in repr(settings)
+ assert settings.alpaca_api_key.get_secret_value() == "my-secret-key"
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
index 239ee64..b44a713 100644
--- a/shared/tests/test_db.py
+++ b/shared/tests/test_db.py
@@ -1,10 +1,11 @@
"""Tests for the SQLAlchemy async database layer."""
-import pytest
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
+import pytest
+
def make_candle():
from shared.models import Candle
@@ -12,7 +13,7 @@ def make_candle():
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("51000"),
low=Decimal("49500"),
@@ -22,7 +23,7 @@ def make_candle():
def make_signal():
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
return Signal(
id="sig-1",
@@ -32,12 +33,12 @@ def make_signal():
price=Decimal("50000"),
quantity=Decimal("0.1"),
reason="Golden cross",
- created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
def make_order():
- from shared.models import Order, OrderSide, OrderType, OrderStatus
+ from shared.models import Order, OrderSide, OrderStatus, OrderType
return Order(
id="ord-1",
@@ -48,7 +49,7 @@ def make_order():
price=Decimal("50000"),
quantity=Decimal("0.1"),
status=OrderStatus.PENDING,
- created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
@@ -101,6 +102,54 @@ class TestDatabaseConnect:
mock_create.assert_called_once()
@pytest.mark.asyncio
+ async def test_connect_passes_pool_params_for_postgres(self):
+ from shared.db import Database
+
+ db = Database("postgresql+asyncpg://host/db")
+
+ mock_conn = AsyncMock()
+ mock_cm = AsyncMock()
+ mock_cm.__aenter__.return_value = mock_conn
+
+ mock_engine = MagicMock()
+ mock_engine.begin.return_value = mock_cm
+ mock_engine.dispose = AsyncMock()
+
+ with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create:
+ with patch("shared.db.async_sessionmaker"):
+ with patch("shared.db.Base") as mock_base:
+ mock_base.metadata.create_all = MagicMock()
+ await db.connect(pool_size=5, max_overflow=3, pool_recycle=1800)
+ mock_create.assert_called_once_with(
+ "postgresql+asyncpg://host/db",
+ pool_pre_ping=True,
+ pool_size=5,
+ max_overflow=3,
+ pool_recycle=1800,
+ )
+
+ @pytest.mark.asyncio
+ async def test_connect_skips_pool_params_for_sqlite(self):
+ from shared.db import Database
+
+ db = Database("sqlite+aiosqlite:///test.db")
+
+ mock_conn = AsyncMock()
+ mock_cm = AsyncMock()
+ mock_cm.__aenter__.return_value = mock_conn
+
+ mock_engine = MagicMock()
+ mock_engine.begin.return_value = mock_cm
+ mock_engine.dispose = AsyncMock()
+
+ with patch("shared.db.create_async_engine", return_value=mock_engine) as mock_create:
+ with patch("shared.db.async_sessionmaker"):
+ with patch("shared.db.Base") as mock_base:
+ mock_base.metadata.create_all = MagicMock()
+ await db.connect()
+ mock_create.assert_called_once_with("sqlite+aiosqlite:///test.db")
+
+ @pytest.mark.asyncio
async def test_init_tables_is_alias_for_connect(self):
from shared.db import Database
@@ -211,7 +260,7 @@ class TestUpdateOrderStatus:
db._session_factory = MagicMock(return_value=mock_session)
- filled = datetime(2024, 1, 2, tzinfo=timezone.utc)
+ filled = datetime(2024, 1, 2, tzinfo=UTC)
await db.update_order_status("ord-1", OrderStatus.FILLED, filled)
mock_session.execute.assert_awaited_once()
@@ -230,7 +279,7 @@ class TestGetCandles:
mock_row._mapping = {
"symbol": "AAPL",
"timeframe": "1m",
- "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc),
+ "open_time": datetime(2024, 1, 1, tzinfo=UTC),
"open": Decimal("50000"),
"high": Decimal("51000"),
"low": Decimal("49500"),
@@ -396,7 +445,7 @@ class TestGetPortfolioSnapshots:
mock_row.total_value = Decimal("10000")
mock_row.realized_pnl = Decimal("0")
mock_row.unrealized_pnl = Decimal("500")
- mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=timezone.utc)
+ mock_row.snapshot_at = datetime(2024, 1, 1, tzinfo=UTC)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [mock_row]
diff --git a/shared/tests/test_db_news.py b/shared/tests/test_db_news.py
index a2c9140..c184bed 100644
--- a/shared/tests/test_db_news.py
+++ b/shared/tests/test_db_news.py
@@ -1,11 +1,12 @@
"""Tests for database news/sentiment methods. Uses in-memory SQLite."""
+from datetime import UTC, date, datetime
+
import pytest
-from datetime import datetime, date, timezone
from shared.db import Database
-from shared.models import NewsItem, NewsCategory
-from shared.sentiment_models import SymbolScore, MarketSentiment
+from shared.models import NewsCategory, NewsItem
+from shared.sentiment_models import MarketSentiment, SymbolScore
@pytest.fixture
@@ -20,7 +21,7 @@ async def test_insert_and_get_news_items(db):
item = NewsItem(
source="finnhub",
headline="AAPL earnings beat",
- published_at=datetime(2026, 4, 2, 12, 0, tzinfo=timezone.utc),
+ published_at=datetime(2026, 4, 2, 12, 0, tzinfo=UTC),
sentiment=0.8,
category=NewsCategory.EARNINGS,
symbols=["AAPL"],
@@ -40,7 +41,7 @@ async def test_upsert_symbol_score(db):
policy_score=0.0,
filing_score=0.2,
composite=0.3,
- updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
)
await db.upsert_symbol_score(score)
scores = await db.get_top_symbol_scores(limit=5)
@@ -55,7 +56,7 @@ async def test_upsert_market_sentiment(db):
vix=18.2,
fed_stance="neutral",
market_regime="neutral",
- updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
)
await db.upsert_market_sentiment(ms)
result = await db.get_latest_market_sentiment()
diff --git a/shared/tests/test_events.py b/shared/tests/test_events.py
index 6077d93..1ccd904 100644
--- a/shared/tests/test_events.py
+++ b/shared/tests/test_events.py
@@ -1,7 +1,7 @@
"""Tests for shared event types."""
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
def make_candle():
@@ -10,7 +10,7 @@ def make_candle():
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("51000"),
low=Decimal("49500"),
@@ -20,7 +20,7 @@ def make_candle():
def make_signal():
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
return Signal(
strategy="test",
@@ -59,7 +59,7 @@ def test_candle_event_deserialize():
def test_signal_event_serialize():
"""Test SignalEvent serializes to dict correctly."""
- from shared.events import SignalEvent, EventType
+ from shared.events import EventType, SignalEvent
signal = make_signal()
event = SignalEvent(data=signal)
@@ -71,7 +71,7 @@ def test_signal_event_serialize():
def test_event_from_dict_dispatch():
"""Test Event.from_dict dispatches to correct class."""
- from shared.events import Event, CandleEvent, SignalEvent
+ from shared.events import CandleEvent, Event, SignalEvent
candle = make_candle()
event = CandleEvent(data=candle)
diff --git a/shared/tests/test_models.py b/shared/tests/test_models.py
index 04098ce..40bb791 100644
--- a/shared/tests/test_models.py
+++ b/shared/tests/test_models.py
@@ -1,8 +1,8 @@
"""Tests for shared models and settings."""
import os
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
from unittest.mock import patch
@@ -12,8 +12,11 @@ def test_settings_defaults():
with patch.dict(os.environ, {}, clear=False):
settings = Settings()
- assert settings.redis_url == "redis://localhost:6379"
- assert settings.database_url == "postgresql://trading:trading@localhost:5432/trading"
+ assert settings.redis_url.get_secret_value() == "redis://localhost:6379"
+ assert (
+ settings.database_url.get_secret_value()
+ == "postgresql://trading:trading@localhost:5432/trading"
+ )
assert settings.log_level == "INFO"
assert settings.risk_max_position_size == 0.1
assert settings.risk_stop_loss_pct == 5.0
@@ -25,7 +28,7 @@ def test_candle_creation():
"""Test Candle model creation."""
from shared.models import Candle
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
candle = Candle(
symbol="AAPL",
timeframe="1m",
@@ -47,7 +50,7 @@ def test_candle_creation():
def test_signal_creation():
"""Test Signal model creation."""
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
signal = Signal(
strategy="rsi_strategy",
@@ -69,9 +72,10 @@ def test_signal_creation():
def test_order_creation():
"""Test Order model creation with defaults."""
- from shared.models import Order, OrderSide, OrderType, OrderStatus
import uuid
+ from shared.models import Order, OrderSide, OrderStatus, OrderType
+
signal_id = str(uuid.uuid4())
order = Order(
signal_id=signal_id,
@@ -90,7 +94,7 @@ def test_order_creation():
def test_signal_conviction_default():
"""Test Signal defaults for conviction, stop_loss, take_profit."""
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
signal = Signal(
strategy="rsi",
@@ -107,7 +111,7 @@ def test_signal_conviction_default():
def test_signal_with_stops():
"""Test Signal with explicit conviction, stop_loss, take_profit."""
- from shared.models import Signal, OrderSide
+ from shared.models import OrderSide, Signal
signal = Signal(
strategy="rsi",
diff --git a/shared/tests/test_news_events.py b/shared/tests/test_news_events.py
index 384796a..f748d8a 100644
--- a/shared/tests/test_news_events.py
+++ b/shared/tests/test_news_events.py
@@ -1,16 +1,16 @@
"""Tests for NewsEvent."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
+from shared.events import Event, EventType, NewsEvent
from shared.models import NewsCategory, NewsItem
-from shared.events import NewsEvent, EventType, Event
def test_news_event_to_dict():
item = NewsItem(
source="finnhub",
headline="Test",
- published_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ published_at=datetime(2026, 4, 2, tzinfo=UTC),
sentiment=0.5,
category=NewsCategory.MACRO,
)
diff --git a/shared/tests/test_notifier.py b/shared/tests/test_notifier.py
index 6c81369..cc98a56 100644
--- a/shared/tests/test_notifier.py
+++ b/shared/tests/test_notifier.py
@@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
-from shared.models import Signal, Order, OrderSide, OrderType, OrderStatus, Position
+from shared.models import Order, OrderSide, OrderStatus, OrderType, Position, Signal
from shared.notifier import TelegramNotifier
diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py
index e287777..e0781af 100644
--- a/shared/tests/test_resilience.py
+++ b/shared/tests/test_resilience.py
@@ -1,139 +1,176 @@
-"""Tests for retry with backoff and circuit breaker."""
+"""Tests for shared.resilience module."""
-import time
+import asyncio
import pytest
-from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff
+from shared.resilience import CircuitBreaker, async_timeout, retry_async
+# --- retry_async tests ---
-# ---------------------------------------------------------------------------
-# retry_with_backoff tests
-# ---------------------------------------------------------------------------
-
-@pytest.mark.asyncio
-async def test_retry_succeeds_first_try():
+async def test_succeeds_without_retry():
+ """Function succeeds first try, called once."""
call_count = 0
- @retry_with_backoff(max_retries=3, base_delay=0.01)
- async def succeed():
+ @retry_async()
+ async def fn():
nonlocal call_count
call_count += 1
return "ok"
- result = await succeed()
+ result = await fn()
assert result == "ok"
assert call_count == 1
-@pytest.mark.asyncio
-async def test_retry_succeeds_after_failures():
+async def test_retries_on_failure_then_succeeds():
+ """Fails twice then succeeds, verify call count."""
call_count = 0
- @retry_with_backoff(max_retries=3, base_delay=0.01)
- async def flaky():
+ @retry_async(max_retries=3, base_delay=0.01)
+ async def fn():
nonlocal call_count
call_count += 1
if call_count < 3:
- raise ValueError("not yet")
+ raise RuntimeError("transient")
return "recovered"
- result = await flaky()
+ result = await fn()
assert result == "recovered"
assert call_count == 3
-@pytest.mark.asyncio
-async def test_retry_raises_after_max_retries():
+async def test_raises_after_max_retries():
+ """Always fails, raises after max retries."""
call_count = 0
- @retry_with_backoff(max_retries=3, base_delay=0.01)
- async def always_fail():
+ @retry_async(max_retries=3, base_delay=0.01)
+ async def fn():
nonlocal call_count
call_count += 1
- raise RuntimeError("permanent")
+ raise ValueError("permanent")
- with pytest.raises(RuntimeError, match="permanent"):
- await always_fail()
- # 1 initial + 3 retries = 4 calls
+ with pytest.raises(ValueError, match="permanent"):
+ await fn()
+
+ # 1 initial + 3 retries = 4 total calls
assert call_count == 4
-@pytest.mark.asyncio
-async def test_retry_respects_max_delay():
- """Backoff should be capped at max_delay."""
+async def test_no_retry_on_excluded_exception():
+ """Excluded exception raises immediately, call count = 1."""
+ call_count = 0
- @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02)
- async def always_fail():
- raise RuntimeError("fail")
+ @retry_async(max_retries=3, base_delay=0.01, exclude=(TypeError,))
+ async def fn():
+ nonlocal call_count
+ call_count += 1
+ raise TypeError("excluded")
- start = time.monotonic()
- with pytest.raises(RuntimeError):
- await always_fail()
- elapsed = time.monotonic() - start
- # With max_delay=0.02 and 2 retries, total delay should be small
- assert elapsed < 0.5
+ with pytest.raises(TypeError, match="excluded"):
+ await fn()
+
+ assert call_count == 1
-# ---------------------------------------------------------------------------
-# CircuitBreaker tests
-# ---------------------------------------------------------------------------
+# --- CircuitBreaker tests ---
-def test_circuit_starts_closed():
- cb = CircuitBreaker(failure_threshold=3, recovery_timeout=0.05)
- assert cb.state == CircuitState.CLOSED
- assert cb.allow_request() is True
+async def test_closed_allows_calls():
+ """CircuitBreaker in closed state passes through."""
+ cb = CircuitBreaker(failure_threshold=5, cooldown=60.0)
+ async def fn():
+ return "ok"
+
+ result = await cb.call(fn)
+ assert result == "ok"
+
+
+async def test_opens_after_threshold():
+ """After N failures, raises RuntimeError."""
+ cb = CircuitBreaker(failure_threshold=3, cooldown=60.0)
+
+ async def fail():
+ raise RuntimeError("fail")
-def test_circuit_opens_after_threshold():
- cb = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0)
for _ in range(3):
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
- assert cb.allow_request() is False
+ with pytest.raises(RuntimeError, match="fail"):
+ await cb.call(fail)
+ # Now the breaker should be open
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(fail)
+
+
+async def test_half_open_after_cooldown():
+ """After cooldown, allows recovery attempt."""
+ cb = CircuitBreaker(failure_threshold=2, cooldown=0.05)
+
+ async def fail():
+ raise RuntimeError("fail")
+
+ # Trip the breaker
+ for _ in range(2):
+ with pytest.raises(RuntimeError, match="fail"):
+ await cb.call(fail)
+
+ # Breaker is open
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(fail)
+
+ # Wait for cooldown
+ await asyncio.sleep(0.06)
+
+ # Now should allow a call (half_open). Succeed to close it.
+ async def succeed():
+ return "recovered"
+
+ result = await cb.call(succeed)
+ assert result == "recovered"
+
+ # Breaker should be closed again
+ result = await cb.call(succeed)
+ assert result == "recovered"
+
+
+async def test_half_open_reopens_on_failure():
+ cb = CircuitBreaker(failure_threshold=2, cooldown=0.05)
+
+ async def always_fail():
+ raise ConnectionError("fail")
-def test_circuit_rejects_when_open():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0)
- cb.record_failure()
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
- assert cb.allow_request() is False
+ # Trip the breaker
+ for _ in range(2):
+ with pytest.raises(ConnectionError):
+ await cb.call(always_fail)
+ # Wait for cooldown
+ await asyncio.sleep(0.1)
-def test_circuit_half_open_after_timeout():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
- cb.record_failure()
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
+ # Half-open probe should fail and re-open
+ with pytest.raises(ConnectionError):
+ await cb.call(always_fail)
- time.sleep(0.06)
- assert cb.allow_request() is True
- assert cb.state == CircuitState.HALF_OPEN
+ # Should be open again (no cooldown wait)
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(always_fail)
-def test_circuit_closes_on_success():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
- cb.record_failure()
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
+# --- async_timeout tests ---
- time.sleep(0.06)
- cb.allow_request() # triggers HALF_OPEN
- assert cb.state == CircuitState.HALF_OPEN
- cb.record_success()
- assert cb.state == CircuitState.CLOSED
- assert cb.allow_request() is True
+async def test_completes_within_timeout():
+ """async_timeout doesn't interfere with fast operations."""
+ async with async_timeout(1.0):
+ await asyncio.sleep(0.01)
+ result = 42
+ assert result == 42
-def test_circuit_reopens_on_failure_in_half_open():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
- cb.record_failure()
- cb.record_failure()
- time.sleep(0.06)
- cb.allow_request() # HALF_OPEN
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
+async def test_raises_on_timeout():
+ """async_timeout raises TimeoutError for slow operations."""
+ with pytest.raises(TimeoutError):
+ async with async_timeout(0.05):
+ await asyncio.sleep(1.0)
diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py
index dc6355e..c9311dd 100644
--- a/shared/tests/test_sa_models.py
+++ b/shared/tests/test_sa_models.py
@@ -72,6 +72,9 @@ class TestSignalRow:
"price",
"quantity",
"reason",
+ "conviction",
+ "stop_loss",
+ "take_profit",
"created_at",
}
assert expected == cols
@@ -124,44 +127,6 @@ class TestOrderRow:
assert fk_cols == {"signal_id": "signals.id"}
-class TestTradeRow:
- def test_table_name(self):
- from shared.sa_models import TradeRow
-
- assert TradeRow.__tablename__ == "trades"
-
- def test_columns(self):
- from shared.sa_models import TradeRow
-
- mapper = inspect(TradeRow)
- cols = {c.key for c in mapper.column_attrs}
- expected = {
- "id",
- "order_id",
- "symbol",
- "side",
- "price",
- "quantity",
- "fee",
- "traded_at",
- }
- assert expected == cols
-
- def test_primary_key(self):
- from shared.sa_models import TradeRow
-
- mapper = inspect(TradeRow)
- pk_cols = [c.name for c in mapper.mapper.primary_key]
- assert pk_cols == ["id"]
-
- def test_order_id_foreign_key(self):
- from shared.sa_models import TradeRow
-
- table = TradeRow.__table__
- fk_cols = {fk.parent.name: fk.target_fullname for fk in table.foreign_keys}
- assert fk_cols == {"order_id": "orders.id"}
-
-
class TestPositionRow:
def test_table_name(self):
from shared.sa_models import PositionRow
@@ -233,11 +198,3 @@ class TestStatusDefault:
status_col = table.c.status
assert status_col.server_default is not None
assert status_col.server_default.arg == "PENDING"
-
- def test_trade_fee_server_default(self):
- from shared.sa_models import TradeRow
-
- table = TradeRow.__table__
- fee_col = table.c.fee
- assert fee_col.server_default is not None
- assert fee_col.server_default.arg == "0"
diff --git a/shared/tests/test_sa_news_models.py b/shared/tests/test_sa_news_models.py
index 91e6d4a..dc2d026 100644
--- a/shared/tests/test_sa_news_models.py
+++ b/shared/tests/test_sa_news_models.py
@@ -1,6 +1,6 @@
"""Tests for news-related SQLAlchemy models."""
-from shared.sa_models import NewsItemRow, SymbolScoreRow, MarketSentimentRow, StockSelectionRow
+from shared.sa_models import MarketSentimentRow, NewsItemRow, StockSelectionRow, SymbolScoreRow
def test_news_item_row_tablename():
diff --git a/shared/tests/test_sentiment.py b/shared/tests/test_sentiment.py
deleted file mode 100644
index 9bd8ea3..0000000
--- a/shared/tests/test_sentiment.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""Tests for market sentiment module."""
-
-from shared.sentiment import SentimentData
-
-
-def test_sentiment_should_buy_default_no_data():
- s = SentimentData()
- assert s.should_buy is True
- assert s.should_block is False
-
-
-def test_sentiment_should_buy_low_fear_greed():
- s = SentimentData(fear_greed_value=15)
- assert s.should_buy is True
-
-
-def test_sentiment_should_not_buy_on_greed():
- s = SentimentData(fear_greed_value=75)
- assert s.should_buy is False
-
-
-def test_sentiment_should_not_buy_negative_news():
- s = SentimentData(news_sentiment=-0.4)
- assert s.should_buy is False
-
-
-def test_sentiment_should_buy_positive_news():
- s = SentimentData(fear_greed_value=50, news_sentiment=0.3)
- assert s.should_buy is True
-
-
-def test_sentiment_should_block_extreme_greed():
- s = SentimentData(fear_greed_value=85)
- assert s.should_block is True
-
-
-def test_sentiment_should_block_very_negative_news():
- s = SentimentData(news_sentiment=-0.6)
- assert s.should_block is True
-
-
-def test_sentiment_no_block_on_neutral():
- s = SentimentData(fear_greed_value=50, news_sentiment=0.0)
- assert s.should_block is False
diff --git a/shared/tests/test_sentiment_aggregator.py b/shared/tests/test_sentiment_aggregator.py
index a99c711..9193785 100644
--- a/shared/tests/test_sentiment_aggregator.py
+++ b/shared/tests/test_sentiment_aggregator.py
@@ -1,7 +1,9 @@
"""Tests for sentiment aggregator."""
+from datetime import UTC, datetime, timedelta
+
import pytest
-from datetime import datetime, timezone, timedelta
+
from shared.sentiment import SentimentAggregator
@@ -12,25 +14,25 @@ def aggregator():
def test_freshness_decay_recent():
a = SentimentAggregator()
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
assert a._freshness_decay(now, now) == 1.0
def test_freshness_decay_3_hours():
a = SentimentAggregator()
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
assert a._freshness_decay(now - timedelta(hours=3), now) == 0.7
def test_freshness_decay_12_hours():
a = SentimentAggregator()
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
assert a._freshness_decay(now - timedelta(hours=12), now) == 0.3
def test_freshness_decay_old():
a = SentimentAggregator()
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
assert a._freshness_decay(now - timedelta(days=2), now) == 0.0
@@ -44,7 +46,7 @@ def test_compute_composite():
def test_aggregate_news_by_symbol(aggregator):
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
news_items = [
{"symbols": ["AAPL"], "sentiment": 0.8, "category": "earnings", "published_at": now},
{
@@ -64,7 +66,7 @@ def test_aggregate_news_by_symbol(aggregator):
def test_aggregate_empty(aggregator):
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
assert aggregator.aggregate([], now) == {}
diff --git a/shared/tests/test_sentiment_models.py b/shared/tests/test_sentiment_models.py
index 25fc371..e00ffa6 100644
--- a/shared/tests/test_sentiment_models.py
+++ b/shared/tests/test_sentiment_models.py
@@ -1,16 +1,16 @@
"""Tests for news and sentiment models."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from shared.models import NewsCategory, NewsItem, OrderSide
-from shared.sentiment_models import SymbolScore, MarketSentiment, SelectedStock, Candidate
+from shared.sentiment_models import Candidate, MarketSentiment, SelectedStock, SymbolScore
def test_news_item_defaults():
item = NewsItem(
source="finnhub",
headline="Test headline",
- published_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ published_at=datetime(2026, 4, 2, tzinfo=UTC),
sentiment=0.5,
category=NewsCategory.MACRO,
)
@@ -25,7 +25,7 @@ def test_news_item_with_symbols():
item = NewsItem(
source="rss",
headline="AAPL earnings beat",
- published_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ published_at=datetime(2026, 4, 2, tzinfo=UTC),
sentiment=0.8,
category=NewsCategory.EARNINGS,
symbols=["AAPL"],
@@ -52,7 +52,7 @@ def test_symbol_score():
policy_score=0.0,
filing_score=0.2,
composite=0.3,
- updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
)
assert score.symbol == "AAPL"
assert score.composite == 0.3
@@ -65,7 +65,7 @@ def test_market_sentiment():
vix=32.5,
fed_stance="hawkish",
market_regime="risk_off",
- updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
)
assert ms.market_regime == "risk_off"
assert ms.vix == 32.5
@@ -77,7 +77,7 @@ def test_market_sentiment_no_vix():
fear_greed_label="Neutral",
fed_stance="neutral",
market_regime="neutral",
- updated_at=datetime(2026, 4, 2, tzinfo=timezone.utc),
+ updated_at=datetime(2026, 4, 2, tzinfo=UTC),
)
assert ms.vix is None