summaryrefslogtreecommitdiff
path: root/services/strategy-engine/tests
diff options
context:
space:
mode:
Diffstat (limited to 'services/strategy-engine/tests')
-rw-r--r--services/strategy-engine/tests/conftest.py5
-rw-r--r--services/strategy-engine/tests/test_base_filters.py7
-rw-r--r--services/strategy-engine/tests/test_bollinger_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_combined_strategy.py11
-rw-r--r--services/strategy-engine/tests/test_ema_crossover_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_engine.py8
-rw-r--r--services/strategy-engine/tests/test_grid_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_indicators.py9
-rw-r--r--services/strategy-engine/tests/test_macd_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_moc_strategy.py7
-rw-r--r--services/strategy-engine/tests/test_multi_symbol.py10
-rw-r--r--services/strategy-engine/tests/test_plugin_loader.py2
-rw-r--r--services/strategy-engine/tests/test_rsi_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_stock_selector.py111
-rw-r--r--services/strategy-engine/tests/test_strategy_validation.py8
-rw-r--r--services/strategy-engine/tests/test_volume_profile_strategy.py15
-rw-r--r--services/strategy-engine/tests/test_vwap_strategy.py12
17 files changed, 174 insertions, 61 deletions
diff --git a/services/strategy-engine/tests/conftest.py b/services/strategy-engine/tests/conftest.py
index eb31b23..2b909ef 100644
--- a/services/strategy-engine/tests/conftest.py
+++ b/services/strategy-engine/tests/conftest.py
@@ -7,3 +7,8 @@ from pathlib import Path
STRATEGIES_DIR = Path(__file__).parent.parent / "strategies"
if str(STRATEGIES_DIR) not in sys.path:
sys.path.insert(0, str(STRATEGIES_DIR.parent))
+
+# Ensure the worktree's strategy_engine src is preferred over any installed version
+WORKTREE_SRC = Path(__file__).parent.parent / "src"
+if str(WORKTREE_SRC) not in sys.path:
+ sys.path.insert(0, str(WORKTREE_SRC))
diff --git a/services/strategy-engine/tests/test_base_filters.py b/services/strategy-engine/tests/test_base_filters.py
index ae9ca05..66adec7 100644
--- a/services/strategy-engine/tests/test_base_filters.py
+++ b/services/strategy-engine/tests/test_base_filters.py
@@ -5,12 +5,13 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
-from shared.models import Candle, Signal, OrderSide
from strategies.base import BaseStrategy
+from shared.models import Candle, OrderSide, Signal
+
class DummyStrategy(BaseStrategy):
name = "dummy"
@@ -45,7 +46,7 @@ def _candle(price=100.0, volume=10.0, high=None, low=None):
return Candle(
symbol="AAPL",
timeframe="1h",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal(str(price)),
high=Decimal(str(h)),
low=Decimal(str(lo)),
diff --git a/services/strategy-engine/tests/test_bollinger_strategy.py b/services/strategy-engine/tests/test_bollinger_strategy.py
index 8261377..70ec66e 100644
--- a/services/strategy-engine/tests/test_bollinger_strategy.py
+++ b/services/strategy-engine/tests/test_bollinger_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the Bollinger Bands strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.bollinger_strategy import BollingerStrategy
from shared.models import Candle, OrderSide
-from strategies.bollinger_strategy import BollingerStrategy
def make_candle(close: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_combined_strategy.py b/services/strategy-engine/tests/test_combined_strategy.py
index 8a4dc74..6a15250 100644
--- a/services/strategy-engine/tests/test_combined_strategy.py
+++ b/services/strategy-engine/tests/test_combined_strategy.py
@@ -5,13 +5,14 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
-import pytest
-from shared.models import Candle, Signal, OrderSide
-from strategies.combined_strategy import CombinedStrategy
+import pytest
from strategies.base import BaseStrategy
+from strategies.combined_strategy import CombinedStrategy
+
+from shared.models import Candle, OrderSide, Signal
class AlwaysBuyStrategy(BaseStrategy):
@@ -74,7 +75,7 @@ def _candle(price=100.0):
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal(str(price)),
high=Decimal(str(price + 10)),
low=Decimal(str(price - 10)),
diff --git a/services/strategy-engine/tests/test_ema_crossover_strategy.py b/services/strategy-engine/tests/test_ema_crossover_strategy.py
index 7028eb0..af2b587 100644
--- a/services/strategy-engine/tests/test_ema_crossover_strategy.py
+++ b/services/strategy-engine/tests/test_ema_crossover_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the EMA Crossover strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.ema_crossover_strategy import EmaCrossoverStrategy
from shared.models import Candle, OrderSide
-from strategies.ema_crossover_strategy import EmaCrossoverStrategy
def make_candle(close: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_engine.py b/services/strategy-engine/tests/test_engine.py
index 2623027..fa888b5 100644
--- a/services/strategy-engine/tests/test_engine.py
+++ b/services/strategy-engine/tests/test_engine.py
@@ -1,21 +1,21 @@
"""Tests for the StrategyEngine."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock
import pytest
+from strategy_engine.engine import StrategyEngine
-from shared.models import Candle, Signal, OrderSide
from shared.events import CandleEvent
-from strategy_engine.engine import StrategyEngine
+from shared.models import Candle, OrderSide, Signal
def make_candle_event() -> dict:
candle = Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("50100"),
low=Decimal("49900"),
diff --git a/services/strategy-engine/tests/test_grid_strategy.py b/services/strategy-engine/tests/test_grid_strategy.py
index 878b900..f697012 100644
--- a/services/strategy-engine/tests/test_grid_strategy.py
+++ b/services/strategy-engine/tests/test_grid_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the Grid strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.grid_strategy import GridStrategy
from shared.models import Candle, OrderSide
-from strategies.grid_strategy import GridStrategy
def make_candle(close: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_indicators.py b/services/strategy-engine/tests/test_indicators.py
index 481569b..3147fc4 100644
--- a/services/strategy-engine/tests/test_indicators.py
+++ b/services/strategy-engine/tests/test_indicators.py
@@ -5,14 +5,13 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
-import pandas as pd
import numpy as np
+import pandas as pd
import pytest
-
-from strategies.indicators.trend import sma, ema, macd, adx
-from strategies.indicators.volatility import atr, bollinger_bands
from strategies.indicators.momentum import rsi, stochastic
-from strategies.indicators.volume import volume_sma, volume_ratio, obv
+from strategies.indicators.trend import adx, ema, macd, sma
+from strategies.indicators.volatility import atr, bollinger_bands
+from strategies.indicators.volume import obv, volume_ratio, volume_sma
class TestTrend:
diff --git a/services/strategy-engine/tests/test_macd_strategy.py b/services/strategy-engine/tests/test_macd_strategy.py
index 556fd4c..7fac16f 100644
--- a/services/strategy-engine/tests/test_macd_strategy.py
+++ b/services/strategy-engine/tests/test_macd_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the MACD strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.macd_strategy import MacdStrategy
from shared.models import Candle, OrderSide
-from strategies.macd_strategy import MacdStrategy
def _candle(price: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(price)),
high=Decimal(str(price)),
low=Decimal(str(price)),
diff --git a/services/strategy-engine/tests/test_moc_strategy.py b/services/strategy-engine/tests/test_moc_strategy.py
index 1928a28..076e846 100644
--- a/services/strategy-engine/tests/test_moc_strategy.py
+++ b/services/strategy-engine/tests/test_moc_strategy.py
@@ -5,19 +5,20 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
-from shared.models import Candle, OrderSide
from strategies.moc_strategy import MocStrategy
+from shared.models import Candle, OrderSide
+
def _candle(price, hour=20, minute=0, volume=100.0, day=1, open_price=None):
op = open_price if open_price is not None else price - 1 # Default: bullish
return Candle(
symbol="AAPL",
timeframe="5Min",
- open_time=datetime(2025, 1, day, hour, minute, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, day, hour, minute, tzinfo=UTC),
open=Decimal(str(op)),
high=Decimal(str(price + 1)),
low=Decimal(str(min(op, price) - 1)),
diff --git a/services/strategy-engine/tests/test_multi_symbol.py b/services/strategy-engine/tests/test_multi_symbol.py
index 671a9d3..922bfc2 100644
--- a/services/strategy-engine/tests/test_multi_symbol.py
+++ b/services/strategy-engine/tests/test_multi_symbol.py
@@ -9,11 +9,13 @@ import pytest
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+from datetime import UTC, datetime
+from decimal import Decimal
+
from strategy_engine.engine import StrategyEngine
+
from shared.events import CandleEvent
from shared.models import Candle
-from decimal import Decimal
-from datetime import datetime, timezone
@pytest.mark.asyncio
@@ -24,7 +26,7 @@ async def test_engine_processes_multiple_streams():
candle_btc = Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("51000"),
low=Decimal("49000"),
@@ -34,7 +36,7 @@ async def test_engine_processes_multiple_streams():
candle_eth = Candle(
symbol="MSFT",
timeframe="1m",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal("3000"),
high=Decimal("3100"),
low=Decimal("2900"),
diff --git a/services/strategy-engine/tests/test_plugin_loader.py b/services/strategy-engine/tests/test_plugin_loader.py
index 5191fc3..7bd450f 100644
--- a/services/strategy-engine/tests/test_plugin_loader.py
+++ b/services/strategy-engine/tests/test_plugin_loader.py
@@ -2,10 +2,8 @@
from pathlib import Path
-
from strategy_engine.plugin_loader import load_strategies
-
STRATEGIES_DIR = Path(__file__).parent.parent / "strategies"
diff --git a/services/strategy-engine/tests/test_rsi_strategy.py b/services/strategy-engine/tests/test_rsi_strategy.py
index 6d31fd5..6c74f0b 100644
--- a/services/strategy-engine/tests/test_rsi_strategy.py
+++ b/services/strategy-engine/tests/test_rsi_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the RSI strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.rsi_strategy import RsiStrategy
from shared.models import Candle, OrderSide
-from strategies.rsi_strategy import RsiStrategy
def make_candle(close: float, idx: int = 0) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_stock_selector.py b/services/strategy-engine/tests/test_stock_selector.py
new file mode 100644
index 0000000..76b8541
--- /dev/null
+++ b/services/strategy-engine/tests/test_stock_selector.py
@@ -0,0 +1,111 @@
+"""Tests for stock selector engine."""
+
+from datetime import UTC, datetime
+from unittest.mock import AsyncMock, MagicMock
+
+from strategy_engine.stock_selector import (
+ SentimentCandidateSource,
+ StockSelector,
+ _extract_json_array,
+ _parse_llm_selections,
+)
+
+
+async def test_sentiment_candidate_source():
+ mock_db = MagicMock()
+ mock_db.get_top_symbol_scores = AsyncMock(
+ return_value=[
+ {"symbol": "AAPL", "composite": 0.8, "news_count": 5},
+ {"symbol": "NVDA", "composite": 0.6, "news_count": 3},
+ ]
+ )
+
+ source = SentimentCandidateSource(mock_db)
+ candidates = await source.get_candidates()
+
+ assert len(candidates) == 2
+ assert candidates[0].symbol == "AAPL"
+ assert candidates[0].source == "sentiment"
+
+
+def test_parse_llm_selections_valid():
+ llm_response = """
+ [
+ {"symbol": "NVDA", "side": "BUY", "conviction": 0.85, "reason": "AI demand", "key_news": ["NVDA beats earnings"]},
+ {"symbol": "XOM", "side": "BUY", "conviction": 0.72, "reason": "Oil surge", "key_news": ["Oil prices up"]}
+ ]
+ """
+ selections = _parse_llm_selections(llm_response)
+ assert len(selections) == 2
+ assert selections[0].symbol == "NVDA"
+ assert selections[0].conviction == 0.85
+
+
+def test_parse_llm_selections_invalid():
+ selections = _parse_llm_selections("not json")
+ assert selections == []
+
+
+def test_parse_llm_selections_with_markdown():
+ llm_response = """
+ Here are my picks:
+ ```json
+ [
+ {"symbol": "TSLA", "side": "BUY", "conviction": 0.7, "reason": "Momentum", "key_news": ["Tesla rally"]}
+ ]
+ ```
+ """
+ selections = _parse_llm_selections(llm_response)
+ assert len(selections) == 1
+ assert selections[0].symbol == "TSLA"
+
+
+def test_extract_json_array_from_markdown():
+ text = '```json\n[{"symbol": "AAPL", "score": 0.9}]\n```'
+ result = _extract_json_array(text)
+ assert result == [{"symbol": "AAPL", "score": 0.9}]
+
+
+def test_extract_json_array_bare():
+ text = '[{"symbol": "TSLA"}]'
+ result = _extract_json_array(text)
+ assert result == [{"symbol": "TSLA"}]
+
+
+def test_extract_json_array_invalid():
+ assert _extract_json_array("not json") is None
+
+
+def test_extract_json_array_filters_non_dicts():
+ text = '[{"symbol": "AAPL"}, "bad", 42]'
+ result = _extract_json_array(text)
+ assert result == [{"symbol": "AAPL"}]
+
+
+async def test_selector_close():
+ selector = StockSelector(
+ db=MagicMock(), broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test"
+ )
+ # No session yet - close should be safe
+ await selector.close()
+ assert selector._http_session is None
+
+
+async def test_selector_blocks_on_risk_off():
+ mock_db = MagicMock()
+ mock_db.get_latest_market_sentiment = AsyncMock(
+ return_value={
+ "fear_greed": 15,
+ "fear_greed_label": "Extreme Fear",
+ "vix": 35.0,
+ "fed_stance": "neutral",
+ "market_regime": "risk_off",
+ "updated_at": datetime.now(UTC),
+ }
+ )
+
+ selector = StockSelector(
+ db=mock_db, broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test"
+ )
+ result = await selector.select()
+ assert result == []
diff --git a/services/strategy-engine/tests/test_strategy_validation.py b/services/strategy-engine/tests/test_strategy_validation.py
index debab1f..0d9607a 100644
--- a/services/strategy-engine/tests/test_strategy_validation.py
+++ b/services/strategy-engine/tests/test_strategy_validation.py
@@ -1,13 +1,11 @@
import pytest
-
-from strategies.rsi_strategy import RsiStrategy
-from strategies.macd_strategy import MacdStrategy
from strategies.bollinger_strategy import BollingerStrategy
from strategies.ema_crossover_strategy import EmaCrossoverStrategy
from strategies.grid_strategy import GridStrategy
-from strategies.vwap_strategy import VwapStrategy
+from strategies.macd_strategy import MacdStrategy
+from strategies.rsi_strategy import RsiStrategy
from strategies.volume_profile_strategy import VolumeProfileStrategy
-
+from strategies.vwap_strategy import VwapStrategy
# ── RSI ──────────────────────────────────────────────────────────────────
diff --git a/services/strategy-engine/tests/test_volume_profile_strategy.py b/services/strategy-engine/tests/test_volume_profile_strategy.py
index 65ee2e8..f47898c 100644
--- a/services/strategy-engine/tests/test_volume_profile_strategy.py
+++ b/services/strategy-engine/tests/test_volume_profile_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the Volume Profile strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.volume_profile_strategy import VolumeProfileStrategy
from shared.models import Candle, OrderSide
-from strategies.volume_profile_strategy import VolumeProfileStrategy
def make_candle(close: float, volume: float = 1.0) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
@@ -134,13 +134,10 @@ def test_volume_profile_hvn_detection():
# Create a profile with very high volume at price ~100 and low volume elsewhere
# Prices range from 90 to 110, heavy volume concentrated at 100
- candles_data = []
# Low volume at extremes
- for p in [90, 91, 92, 109, 110]:
- candles_data.append((p, 1.0))
+ candles_data = [(p, 1.0) for p in [90, 91, 92, 109, 110]]
# Very high volume around 100
- for _ in range(15):
- candles_data.append((100, 100.0))
+ candles_data.extend((100, 100.0) for _ in range(15))
for price, vol in candles_data:
strategy.on_candle(make_candle(price, vol))
@@ -148,7 +145,7 @@ def test_volume_profile_hvn_detection():
# Access the internal method to verify HVN detection
result = strategy._compute_value_area()
assert result is not None
- poc, va_low, va_high, hvn_levels, lvn_levels = result
+ _poc, _va_low, _va_high, hvn_levels, _lvn_levels = result
# The bin containing price ~100 should have very high volume -> HVN
assert len(hvn_levels) > 0
diff --git a/services/strategy-engine/tests/test_vwap_strategy.py b/services/strategy-engine/tests/test_vwap_strategy.py
index 2c34b01..078d0cf 100644
--- a/services/strategy-engine/tests/test_vwap_strategy.py
+++ b/services/strategy-engine/tests/test_vwap_strategy.py
@@ -1,11 +1,11 @@
"""Tests for the VWAP strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.vwap_strategy import VwapStrategy
from shared.models import Candle, OrderSide
-from strategies.vwap_strategy import VwapStrategy
def make_candle(
@@ -20,7 +20,7 @@ def make_candle(
if low is None:
low = close
if open_time is None:
- open_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
+ open_time = datetime(2024, 1, 1, tzinfo=UTC)
return Candle(
symbol="AAPL",
timeframe="1m",
@@ -111,11 +111,11 @@ def test_vwap_daily_reset():
"""Candles from two different dates cause VWAP to reset."""
strategy = _configured_strategy()
- day1 = datetime(2024, 1, 1, tzinfo=timezone.utc)
- day2 = datetime(2024, 1, 2, tzinfo=timezone.utc)
+ day1 = datetime(2024, 1, 1, tzinfo=UTC)
+ day2 = datetime(2024, 1, 2, tzinfo=UTC)
# Feed 35 candles on day 1 to build VWAP state
- for i in range(35):
+ for _i in range(35):
strategy.on_candle(make_candle(100.0, high=101.0, low=99.0, open_time=day1))
# Verify state is built up