summaryrefslogtreecommitdiff
path: root/services/strategy-engine
diff options
context:
space:
mode:
Diffstat (limited to 'services/strategy-engine')
-rw-r--r--services/strategy-engine/Dockerfile9
-rw-r--r--services/strategy-engine/pyproject.toml6
-rw-r--r--services/strategy-engine/src/strategy_engine/config.py2
-rw-r--r--services/strategy-engine/src/strategy_engine/engine.py8
-rw-r--r--services/strategy-engine/src/strategy_engine/main.py85
-rw-r--r--services/strategy-engine/src/strategy_engine/plugin_loader.py1
-rw-r--r--services/strategy-engine/src/strategy_engine/stock_selector.py418
-rw-r--r--services/strategy-engine/strategies/base.py5
-rw-r--r--services/strategy-engine/strategies/bollinger_strategy.py2
-rw-r--r--services/strategy-engine/strategies/combined_strategy.py2
-rw-r--r--services/strategy-engine/strategies/ema_crossover_strategy.py2
-rw-r--r--services/strategy-engine/strategies/grid_strategy.py5
-rw-r--r--services/strategy-engine/strategies/indicators/__init__.py16
-rw-r--r--services/strategy-engine/strategies/indicators/momentum.py2
-rw-r--r--services/strategy-engine/strategies/indicators/trend.py2
-rw-r--r--services/strategy-engine/strategies/indicators/volatility.py2
-rw-r--r--services/strategy-engine/strategies/indicators/volume.py2
-rw-r--r--services/strategy-engine/strategies/macd_strategy.py2
-rw-r--r--services/strategy-engine/strategies/moc_strategy.py4
-rw-r--r--services/strategy-engine/strategies/rsi_strategy.py2
-rw-r--r--services/strategy-engine/strategies/volume_profile_strategy.py4
-rw-r--r--services/strategy-engine/strategies/vwap_strategy.py4
-rw-r--r--services/strategy-engine/tests/conftest.py5
-rw-r--r--services/strategy-engine/tests/test_base_filters.py7
-rw-r--r--services/strategy-engine/tests/test_bollinger_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_combined_strategy.py11
-rw-r--r--services/strategy-engine/tests/test_ema_crossover_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_engine.py8
-rw-r--r--services/strategy-engine/tests/test_grid_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_indicators.py9
-rw-r--r--services/strategy-engine/tests/test_macd_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_moc_strategy.py7
-rw-r--r--services/strategy-engine/tests/test_multi_symbol.py10
-rw-r--r--services/strategy-engine/tests/test_plugin_loader.py2
-rw-r--r--services/strategy-engine/tests/test_rsi_strategy.py6
-rw-r--r--services/strategy-engine/tests/test_stock_selector.py111
-rw-r--r--services/strategy-engine/tests/test_strategy_validation.py8
-rw-r--r--services/strategy-engine/tests/test_volume_profile_strategy.py15
-rw-r--r--services/strategy-engine/tests/test_vwap_strategy.py12
39 files changed, 713 insertions, 107 deletions
diff --git a/services/strategy-engine/Dockerfile b/services/strategy-engine/Dockerfile
index de635dc..f1484e9 100644
--- a/services/strategy-engine/Dockerfile
+++ b/services/strategy-engine/Dockerfile
@@ -1,9 +1,16 @@
-FROM python:3.12-slim
+FROM python:3.12-slim AS builder
WORKDIR /app
COPY shared/ shared/
RUN pip install --no-cache-dir ./shared
COPY services/strategy-engine/ services/strategy-engine/
RUN pip install --no-cache-dir ./services/strategy-engine
+
+FROM python:3.12-slim
+RUN useradd -r -s /bin/false appuser
+WORKDIR /app
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
COPY services/strategy-engine/strategies/ /app/strategies/
ENV PYTHONPATH=/app
+USER appuser
CMD ["python", "-m", "strategy_engine.main"]
diff --git a/services/strategy-engine/pyproject.toml b/services/strategy-engine/pyproject.toml
index 4f5b6be..e4bfb12 100644
--- a/services/strategy-engine/pyproject.toml
+++ b/services/strategy-engine/pyproject.toml
@@ -3,11 +3,7 @@ name = "strategy-engine"
version = "0.1.0"
description = "Plugin-based strategy execution engine"
requires-python = ">=3.12"
-dependencies = [
- "pandas>=2.0",
- "numpy>=1.20",
- "trading-shared",
-]
+dependencies = ["pandas>=2.1,<3", "numpy>=1.26,<3", "trading-shared"]
[project.optional-dependencies]
dev = ["pytest>=8.0", "pytest-asyncio>=0.23"]
diff --git a/services/strategy-engine/src/strategy_engine/config.py b/services/strategy-engine/src/strategy_engine/config.py
index e3a49c2..9fd9c49 100644
--- a/services/strategy-engine/src/strategy_engine/config.py
+++ b/services/strategy-engine/src/strategy_engine/config.py
@@ -4,6 +4,6 @@ from shared.config import Settings
class StrategyConfig(Settings):
- symbols: list[str] = ["BTC/USDT"]
+ symbols: list[str] = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"]
timeframes: list[str] = ["1m"]
strategy_params: dict = {}
diff --git a/services/strategy-engine/src/strategy_engine/engine.py b/services/strategy-engine/src/strategy_engine/engine.py
index d401aee..4b2c468 100644
--- a/services/strategy-engine/src/strategy_engine/engine.py
+++ b/services/strategy-engine/src/strategy_engine/engine.py
@@ -2,11 +2,11 @@
import logging
-from shared.broker import RedisBroker
-from shared.events import CandleEvent, SignalEvent, Event
-
from strategies.base import BaseStrategy
+from shared.broker import RedisBroker
+from shared.events import CandleEvent, Event, SignalEvent
+
logger = logging.getLogger(__name__)
@@ -26,7 +26,7 @@ class StrategyEngine:
try:
event = Event.from_dict(raw)
except Exception as exc:
- logger.warning("Failed to parse event: %s – %s", raw, exc)
+ logger.warning("Failed to parse event: %s - %s", raw, exc)
continue
if not isinstance(event, CandleEvent):
diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py
index 30de528..3d73058 100644
--- a/services/strategy-engine/src/strategy_engine/main.py
+++ b/services/strategy-engine/src/strategy_engine/main.py
@@ -1,17 +1,25 @@
"""Strategy Engine Service entry point."""
import asyncio
+import zoneinfo
+from datetime import datetime
from pathlib import Path
+import aiohttp
+
+from shared.alpaca import AlpacaClient
from shared.broker import RedisBroker
+from shared.db import Database
from shared.healthcheck import HealthCheckServer
from shared.logging import setup_logging
from shared.metrics import ServiceMetrics
from shared.notifier import TelegramNotifier
-
+from shared.sentiment_models import MarketSentiment
+from shared.shutdown import GracefulShutdown
from strategy_engine.config import StrategyConfig
from strategy_engine.engine import StrategyEngine
from strategy_engine.plugin_loader import load_strategies
+from strategy_engine.stock_selector import StockSelector
# The strategies directory lives alongside the installed package
STRATEGIES_DIR = Path(__file__).parent.parent.parent.parent / "strategies"
@@ -30,23 +38,74 @@ async def process_symbol(engine: StrategyEngine, stream: str, log) -> None:
last_id = await engine.process_once(stream, last_id)
+async def run_stock_selector(
+ selector: StockSelector,
+ notifier: TelegramNotifier,
+ db: Database,
+ config: StrategyConfig,
+ log,
+) -> None:
+ """Run the stock selector once per day at the configured time."""
+ et = zoneinfo.ZoneInfo("America/New_York")
+
+ while True:
+ now_et = datetime.now(et)
+ target_hour, target_min = map(int, config.selector_final_time.split(":"))
+
+ if now_et.hour == target_hour and now_et.minute == target_min:
+ log.info("stock_selector_running")
+ try:
+ selections = await selector.select()
+ if selections:
+ ms_data = await db.get_latest_market_sentiment()
+ ms = None
+ if ms_data:
+ ms = MarketSentiment(**ms_data)
+ await notifier.send_stock_selection(selections, ms)
+ log.info("stock_selector_complete", picks=[s.symbol for s in selections])
+ else:
+ log.info("stock_selector_no_picks")
+ except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
+ log.warning("stock_selector_network_error", error=str(exc))
+ except (ValueError, KeyError, TypeError) as exc:
+ log.warning("stock_selector_data_error", error=str(exc))
+ except Exception as exc:
+ log.error("stock_selector_error", error=str(exc), exc_info=True)
+ await asyncio.sleep(120) # Sleep past this minute
+ else:
+ await asyncio.sleep(30)
+
+
async def run() -> None:
config = StrategyConfig()
log = setup_logging("strategy-engine", config.log_level, config.log_format)
metrics = ServiceMetrics("strategy_engine")
notifier = TelegramNotifier(
- bot_token=config.telegram_bot_token,
+ bot_token=config.telegram_bot_token.get_secret_value(),
chat_id=config.telegram_chat_id,
)
- broker = RedisBroker(config.redis_url)
+ broker = RedisBroker(config.redis_url.get_secret_value())
+
+ db = Database(config.database_url.get_secret_value())
+ await db.connect()
+
+ alpaca = AlpacaClient(
+ api_key=config.alpaca_api_key.get_secret_value(),
+ api_secret=config.alpaca_api_secret.get_secret_value(),
+ paper=config.alpaca_paper,
+ )
+
strategies = load_strategies(STRATEGIES_DIR)
for strategy in strategies:
params = config.strategy_params.get(strategy.name, {})
strategy.configure(params)
+ shutdown = GracefulShutdown()
+ shutdown.install_handlers()
+
log.info("loaded_strategies", count=len(strategies), names=[s.name for s in strategies])
engine = StrategyEngine(broker=broker, strategies=strategies)
@@ -67,9 +126,23 @@ async def run() -> None:
task = asyncio.create_task(process_symbol(engine, stream, log))
tasks.append(task)
- await asyncio.gather(*tasks)
+ if config.anthropic_api_key.get_secret_value():
+ selector = StockSelector(
+ db=db,
+ broker=broker,
+ alpaca=alpaca,
+ anthropic_api_key=config.anthropic_api_key.get_secret_value(),
+ anthropic_model=config.anthropic_model,
+ max_picks=config.selector_max_picks,
+ )
+ tasks.append(
+ asyncio.create_task(run_stock_selector(selector, notifier, db, config, log))
+ )
+ log.info("stock_selector_enabled", time=config.selector_final_time)
+
+ await shutdown.wait()
except Exception as exc:
- log.error("fatal_error", error=str(exc))
+ log.error("fatal_error", error=str(exc), exc_info=True)
await notifier.send_error(str(exc), "strategy-engine")
raise
finally:
@@ -78,6 +151,8 @@ async def run() -> None:
metrics.service_up.labels(service="strategy-engine").set(0)
await notifier.close()
await broker.close()
+ await alpaca.close()
+ await db.close()
def main() -> None:
diff --git a/services/strategy-engine/src/strategy_engine/plugin_loader.py b/services/strategy-engine/src/strategy_engine/plugin_loader.py
index 62e4160..57680db 100644
--- a/services/strategy-engine/src/strategy_engine/plugin_loader.py
+++ b/services/strategy-engine/src/strategy_engine/plugin_loader.py
@@ -5,7 +5,6 @@ import sys
from pathlib import Path
import yaml
-
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/src/strategy_engine/stock_selector.py b/services/strategy-engine/src/strategy_engine/stock_selector.py
new file mode 100644
index 0000000..8657b93
--- /dev/null
+++ b/services/strategy-engine/src/strategy_engine/stock_selector.py
@@ -0,0 +1,418 @@
+"""3-stage stock selector engine: sentiment → technical → LLM."""
+
+import asyncio
+import json
+import logging
+import re
+from datetime import UTC, datetime
+
+import aiohttp
+
+from shared.alpaca import AlpacaClient
+from shared.broker import RedisBroker
+from shared.db import Database
+from shared.models import OrderSide
+from shared.sentiment_models import Candidate, MarketSentiment, SelectedStock
+
+logger = logging.getLogger(__name__)
+
+ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages"
+
+
+def _extract_json_array(text: str) -> list[dict] | None:
+ """Extract a JSON array from text that may contain markdown code blocks."""
+ code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL)
+ if code_block:
+ raw = code_block.group(1)
+ else:
+ array_match = re.search(r"\[.*\]", text, re.DOTALL)
+ if array_match:
+ raw = array_match.group(0)
+ else:
+ raw = text.strip()
+
+ try:
+ data = json.loads(raw)
+ if isinstance(data, list):
+ return [item for item in data if isinstance(item, dict)]
+ return None
+ except (json.JSONDecodeError, TypeError):
+ return None
+
+
+def _parse_llm_selections(text: str) -> list[SelectedStock]:
+ """Parse LLM response into SelectedStock list.
+
+ Handles both bare JSON arrays and markdown code blocks.
+ Returns empty list on any parse error.
+ """
+ items = _extract_json_array(text)
+ if items is None:
+ return []
+
+ selections = []
+ for item in items:
+ try:
+ selection = SelectedStock(
+ symbol=item["symbol"],
+ side=OrderSide(item["side"]),
+ conviction=float(item["conviction"]),
+ reason=item.get("reason", ""),
+ key_news=item.get("key_news", []),
+ )
+ selections.append(selection)
+ except (KeyError, ValueError) as e:
+ logger.warning("Skipping invalid selection item: %s", e)
+ return selections
+
+
+class SentimentCandidateSource:
+ """Generates candidates from DB sentiment scores."""
+
+ def __init__(self, db: Database) -> None:
+ self._db = db
+
+ async def get_candidates(self) -> list[Candidate]:
+ rows = await self._db.get_top_symbol_scores(limit=20)
+ candidates = []
+ for row in rows:
+ composite = float(row.get("composite", 0))
+ if composite == 0:
+ continue
+ candidates.append(
+ Candidate(
+ symbol=row["symbol"],
+ source="sentiment",
+ score=composite,
+ reason=f"composite={composite:.2f}, news_count={row.get('news_count', 0)}",
+ )
+ )
+ return candidates
+
+
+class LLMCandidateSource:
+ """Generates candidates by asking Claude to analyze recent news."""
+
+ def __init__(self, db: Database, api_key: str, model: str) -> None:
+ self._db = db
+ self._api_key = api_key
+ self._model = model
+
+ async def get_candidates(self, session: aiohttp.ClientSession | None = None) -> list[Candidate]:
+ news_items = await self._db.get_recent_news(hours=24)
+ if not news_items:
+ return []
+
+ headlines = []
+ for item in news_items[:50]: # cap at 50 to stay within context
+ symbols = item.get("symbols", [])
+ sym_str = ", ".join(symbols) if symbols else "N/A"
+ headlines.append(f"[{sym_str}] {item['headline']}")
+
+ prompt = (
+ "You are a stock analyst. Given recent news headlines, identify the 5-10 most "
+ "actionable US stock tickers. Return ONLY a JSON array with objects having: "
+ "symbol (ticker), direction ('BUY' or 'SELL'), score (0-1), reason (brief).\n\n"
+ "Headlines:\n" + "\n".join(headlines)
+ )
+
+ own_session = session is None
+ if own_session:
+ session = aiohttp.ClientSession()
+
+ try:
+ async with session.post(
+ ANTHROPIC_API_URL,
+ headers={
+ "x-api-key": self._api_key,
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json",
+ },
+ json={
+ "model": self._model,
+ "max_tokens": 1024,
+ "messages": [{"role": "user", "content": prompt}],
+ },
+ ) as resp:
+ if resp.status != 200:
+ body = await resp.text()
+ logger.error("LLM candidate source error %d: %s", resp.status, body)
+ return []
+ data = await resp.json()
+
+ content = data.get("content", [])
+ text = ""
+ for block in content:
+ if isinstance(block, dict) and block.get("type") == "text":
+ text += block.get("text", "")
+
+ return self._parse_candidates(text)
+ except Exception as e:
+ logger.error("LLMCandidateSource error: %s", e)
+ return []
+ finally:
+ if own_session:
+ await session.close()
+
+ def _parse_candidates(self, text: str) -> list[Candidate]:
+ items = _extract_json_array(text)
+ if items is None:
+ return []
+
+ candidates = []
+ for item in items:
+ try:
+ direction_str = item.get("direction", "BUY")
+ direction = OrderSide(direction_str)
+ except ValueError:
+ direction = None
+ candidates.append(
+ Candidate(
+ symbol=item["symbol"],
+ source="llm",
+ direction=direction,
+ score=float(item.get("score", 0.5)),
+ reason=item.get("reason", ""),
+ )
+ )
+ return candidates
+
+
+def _compute_rsi(closes: list[float], period: int = 14) -> float:
+ """Compute RSI for the last data point."""
+ if len(closes) < period + 1:
+ return 50.0 # neutral if insufficient data
+
+ deltas = [closes[i] - closes[i - 1] for i in range(1, len(closes))]
+ gains = [d if d > 0 else 0.0 for d in deltas]
+ losses = [-d if d < 0 else 0.0 for d in deltas]
+
+ avg_gain = sum(gains[:period]) / period
+ avg_loss = sum(losses[:period]) / period
+
+ for i in range(period, len(deltas)):
+ avg_gain = (avg_gain * (period - 1) + gains[i]) / period
+ avg_loss = (avg_loss * (period - 1) + losses[i]) / period
+
+ if avg_loss == 0:
+ return 100.0
+ rs = avg_gain / avg_loss
+ return 100.0 - (100.0 / (1.0 + rs))
+
+
+class StockSelector:
+ """Orchestrates the 3-stage stock selection pipeline."""
+
+ def __init__(
+ self,
+ db: Database,
+ broker: RedisBroker,
+ alpaca: AlpacaClient,
+ anthropic_api_key: str,
+ anthropic_model: str = "claude-sonnet-4-20250514",
+ max_picks: int = 3,
+ ) -> None:
+ self._db = db
+ self._broker = broker
+ self._alpaca = alpaca
+ self._api_key = anthropic_api_key
+ self._model = anthropic_model
+ self._max_picks = max_picks
+ self._http_session: aiohttp.ClientSession | None = None
+ self._session_lock = asyncio.Lock()
+
+ async def _ensure_session(self) -> aiohttp.ClientSession:
+ async with self._session_lock:
+ if self._http_session is None or self._http_session.closed:
+ self._http_session = aiohttp.ClientSession()
+ return self._http_session
+
+ async def close(self) -> None:
+ if self._http_session and not self._http_session.closed:
+ await self._http_session.close()
+
+ async def select(self) -> list[SelectedStock]:
+ """Run the full 3-stage pipeline and return selected stocks."""
+ # Market gate: check sentiment
+ sentiment_data = await self._db.get_latest_market_sentiment()
+ if sentiment_data is None:
+ logger.warning("No market sentiment data; skipping selection")
+ return []
+
+ market_sentiment = MarketSentiment(**sentiment_data)
+ if market_sentiment.market_regime == "risk_off":
+ logger.info("Market is risk_off; skipping stock selection")
+ return []
+
+ # Stage 1: gather candidates from both sources
+ sentiment_source = SentimentCandidateSource(self._db)
+ llm_source = LLMCandidateSource(self._db, self._api_key, self._model)
+
+ session = await self._ensure_session()
+ sentiment_candidates = await sentiment_source.get_candidates()
+ llm_candidates = await llm_source.get_candidates(session=session)
+
+ candidates = self._merge_candidates(sentiment_candidates, llm_candidates)
+ if not candidates:
+ logger.info("No candidates found")
+ return []
+
+ # Stage 2: technical filter
+ filtered = await self._technical_filter(candidates)
+ if not filtered:
+ logger.info("All candidates filtered out by technical criteria")
+ return []
+
+ # Stage 3: LLM final selection
+ selections = await self._llm_final_select(filtered, market_sentiment)
+
+ # Persist and publish
+ today = datetime.now(UTC).date()
+ sentiment_snapshot = {
+ "fear_greed": market_sentiment.fear_greed,
+ "market_regime": market_sentiment.market_regime,
+ "vix": market_sentiment.vix,
+ }
+ for stock in selections:
+ try:
+ await self._db.insert_stock_selection(
+ trade_date=today,
+ symbol=stock.symbol,
+ side=stock.side.value,
+ conviction=stock.conviction,
+ reason=stock.reason,
+ key_news=stock.key_news,
+ sentiment_snapshot=sentiment_snapshot,
+ )
+ except Exception as e:
+ logger.error("Failed to persist selection for %s: %s", stock.symbol, e)
+
+ try:
+ await self._broker.publish(
+ "selected_stocks",
+ {
+ "symbol": stock.symbol,
+ "side": stock.side.value,
+ "conviction": stock.conviction,
+ "reason": stock.reason,
+ "key_news": stock.key_news,
+ "trade_date": str(today),
+ },
+ )
+ except Exception as e:
+ logger.error("Failed to publish selection for %s: %s", stock.symbol, e)
+
+ return selections
+
+ def _merge_candidates(
+ self, sentiment: list[Candidate], llm: list[Candidate]
+ ) -> list[Candidate]:
+ """Deduplicate candidates by symbol, keeping the higher score."""
+ by_symbol: dict[str, Candidate] = {}
+ for c in sentiment + llm:
+ existing = by_symbol.get(c.symbol)
+ if existing is None or c.score > existing.score:
+ by_symbol[c.symbol] = c
+ return sorted(by_symbol.values(), key=lambda c: c.score, reverse=True)
+
+ async def _technical_filter(self, candidates: list[Candidate]) -> list[Candidate]:
+ """Filter candidates using RSI, EMA20, and volume criteria."""
+ passed = []
+ for candidate in candidates:
+ try:
+ bars = await self._alpaca.get_bars(candidate.symbol, timeframe="1Day", limit=60)
+ if len(bars) < 21:
+ logger.debug("Insufficient bars for %s", candidate.symbol)
+ continue
+
+ closes = [float(b["c"]) for b in bars]
+ volumes = [float(b["v"]) for b in bars]
+
+ rsi = _compute_rsi(closes)
+ if not (30 <= rsi <= 70):
+ logger.debug("%s RSI=%.1f outside 30-70", candidate.symbol, rsi)
+ continue
+
+ ema20 = sum(closes[-20:]) / 20 # simple approximation
+ current_price = closes[-1]
+ if current_price <= ema20:
+ logger.debug(
+ "%s price %.2f <= EMA20 %.2f", candidate.symbol, current_price, ema20
+ )
+ continue
+
+ avg_volume = sum(volumes[:-1]) / max(len(volumes) - 1, 1)
+ current_volume = volumes[-1]
+ if current_volume <= 0.5 * avg_volume:
+ logger.debug(
+ "%s volume %.0f <= 50%% avg %.0f",
+ candidate.symbol,
+ current_volume,
+ avg_volume,
+ )
+ continue
+
+ passed.append(candidate)
+ except Exception as e:
+ logger.warning("Technical filter error for %s: %s", candidate.symbol, e)
+
+ return passed
+
+ async def _llm_final_select(
+ self, candidates: list[Candidate], market_sentiment: MarketSentiment
+ ) -> list[SelectedStock]:
+ """Ask Claude to pick 2-3 stocks with rationale."""
+ candidate_lines = [
+ f"- {c.symbol} (source={c.source}, score={c.score:.2f}, reason={c.reason})"
+ for c in candidates
+ ]
+ market_context = (
+ f"Fear/Greed: {market_sentiment.fear_greed} ({market_sentiment.fear_greed_label}), "
+ f"VIX: {market_sentiment.vix}, "
+ f"Fed stance: {market_sentiment.fed_stance}, "
+ f"Regime: {market_sentiment.market_regime}"
+ )
+
+ prompt = (
+ f"You are a portfolio manager. Select 2-3 stocks for today's session.\n\n"
+ f"Market context: {market_context}\n\n"
+ f"Candidates (already passed technical filters):\n"
+ + "\n".join(candidate_lines)
+ + "\n\n"
+ "Return ONLY a JSON array with objects having:\n"
+ " symbol, side ('BUY' or 'SELL'), conviction (0-1), reason (1-2 sentences), "
+ "key_news (list of 1-3 relevant headlines or facts)\n"
+ f"Select at most {self._max_picks} stocks."
+ )
+
+ try:
+ session = await self._ensure_session()
+ async with session.post(
+ ANTHROPIC_API_URL,
+ headers={
+ "x-api-key": self._api_key,
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json",
+ },
+ json={
+ "model": self._model,
+ "max_tokens": 1024,
+ "messages": [{"role": "user", "content": prompt}],
+ },
+ ) as resp:
+ if resp.status != 200:
+ body = await resp.text()
+ logger.error("LLM final select error %d: %s", resp.status, body)
+ return []
+ data = await resp.json()
+
+ content = data.get("content", [])
+ text = ""
+ for block in content:
+ if isinstance(block, dict) and block.get("type") == "text":
+ text += block.get("text", "")
+
+ return _parse_llm_selections(text)[: self._max_picks]
+ except Exception as e:
+ logger.error("LLM final select error: %s", e)
+ return []
diff --git a/services/strategy-engine/strategies/base.py b/services/strategy-engine/strategies/base.py
index d5be675..1d9d289 100644
--- a/services/strategy-engine/strategies/base.py
+++ b/services/strategy-engine/strategies/base.py
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from collections import deque
from decimal import Decimal
-from typing import Optional
import pandas as pd
@@ -102,7 +101,7 @@ class BaseStrategy(ABC):
def _calculate_atr_stops(
self, entry_price: Decimal, side: str
- ) -> tuple[Optional[Decimal], Optional[Decimal]]:
+ ) -> tuple[Decimal | None, Decimal | None]:
"""Calculate ATR-based stop-loss and take-profit.
Returns (stop_loss, take_profit) as Decimal or (None, None) if not enough data.
@@ -131,7 +130,7 @@ class BaseStrategy(ABC):
return sl, tp
- def _apply_filters(self, signal: Signal) -> Optional[Signal]:
+ def _apply_filters(self, signal: Signal) -> Signal | None:
"""Apply all filters to a signal. Returns signal with SL/TP or None if filtered out."""
if signal is None:
return None
diff --git a/services/strategy-engine/strategies/bollinger_strategy.py b/services/strategy-engine/strategies/bollinger_strategy.py
index ebe7967..02ff09a 100644
--- a/services/strategy-engine/strategies/bollinger_strategy.py
+++ b/services/strategy-engine/strategies/bollinger_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/combined_strategy.py b/services/strategy-engine/strategies/combined_strategy.py
index ba92485..f562918 100644
--- a/services/strategy-engine/strategies/combined_strategy.py
+++ b/services/strategy-engine/strategies/combined_strategy.py
@@ -2,7 +2,7 @@
from decimal import Decimal
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/ema_crossover_strategy.py b/services/strategy-engine/strategies/ema_crossover_strategy.py
index 68d0ba3..9c181f3 100644
--- a/services/strategy-engine/strategies/ema_crossover_strategy.py
+++ b/services/strategy-engine/strategies/ema_crossover_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/grid_strategy.py b/services/strategy-engine/strategies/grid_strategy.py
index 283bfe5..491252e 100644
--- a/services/strategy-engine/strategies/grid_strategy.py
+++ b/services/strategy-engine/strategies/grid_strategy.py
@@ -1,9 +1,8 @@
from decimal import Decimal
-from typing import Optional
import numpy as np
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
@@ -17,7 +16,7 @@ class GridStrategy(BaseStrategy):
self._grid_count: int = 5
self._quantity: Decimal = Decimal("0.01")
self._grid_levels: list[float] = []
- self._last_zone: Optional[int] = None
+ self._last_zone: int | None = None
self._exit_threshold_pct: float = 5.0
self._out_of_range: bool = False
self._in_position: bool = False # Track if we have any grid positions
diff --git a/services/strategy-engine/strategies/indicators/__init__.py b/services/strategy-engine/strategies/indicators/__init__.py
index 3c713e6..01637b7 100644
--- a/services/strategy-engine/strategies/indicators/__init__.py
+++ b/services/strategy-engine/strategies/indicators/__init__.py
@@ -1,21 +1,21 @@
"""Reusable technical indicator functions."""
-from strategies.indicators.trend import ema, sma, macd, adx
-from strategies.indicators.volatility import atr, bollinger_bands, keltner_channels
from strategies.indicators.momentum import rsi, stochastic
-from strategies.indicators.volume import volume_sma, volume_ratio, obv
+from strategies.indicators.trend import adx, ema, macd, sma
+from strategies.indicators.volatility import atr, bollinger_bands, keltner_channels
+from strategies.indicators.volume import obv, volume_ratio, volume_sma
__all__ = [
- "ema",
- "sma",
- "macd",
"adx",
"atr",
"bollinger_bands",
+ "ema",
"keltner_channels",
+ "macd",
+ "obv",
"rsi",
+ "sma",
"stochastic",
- "volume_sma",
"volume_ratio",
- "obv",
+ "volume_sma",
]
diff --git a/services/strategy-engine/strategies/indicators/momentum.py b/services/strategy-engine/strategies/indicators/momentum.py
index c479452..a82210b 100644
--- a/services/strategy-engine/strategies/indicators/momentum.py
+++ b/services/strategy-engine/strategies/indicators/momentum.py
@@ -1,7 +1,7 @@
"""Momentum indicators: RSI, Stochastic."""
-import pandas as pd
import numpy as np
+import pandas as pd
def rsi(closes: pd.Series, period: int = 14) -> pd.Series:
diff --git a/services/strategy-engine/strategies/indicators/trend.py b/services/strategy-engine/strategies/indicators/trend.py
index c94a071..1085199 100644
--- a/services/strategy-engine/strategies/indicators/trend.py
+++ b/services/strategy-engine/strategies/indicators/trend.py
@@ -1,7 +1,7 @@
"""Trend indicators: EMA, SMA, MACD, ADX."""
-import pandas as pd
import numpy as np
+import pandas as pd
def sma(series: pd.Series, period: int) -> pd.Series:
diff --git a/services/strategy-engine/strategies/indicators/volatility.py b/services/strategy-engine/strategies/indicators/volatility.py
index c16143e..da82f26 100644
--- a/services/strategy-engine/strategies/indicators/volatility.py
+++ b/services/strategy-engine/strategies/indicators/volatility.py
@@ -1,7 +1,7 @@
"""Volatility indicators: ATR, Bollinger Bands, Keltner Channels."""
-import pandas as pd
import numpy as np
+import pandas as pd
def atr(
diff --git a/services/strategy-engine/strategies/indicators/volume.py b/services/strategy-engine/strategies/indicators/volume.py
index 502f1ce..d7c6471 100644
--- a/services/strategy-engine/strategies/indicators/volume.py
+++ b/services/strategy-engine/strategies/indicators/volume.py
@@ -1,7 +1,7 @@
"""Volume indicators: Volume SMA, Volume Ratio, OBV."""
-import pandas as pd
import numpy as np
+import pandas as pd
def volume_sma(volumes: pd.Series, period: int = 20) -> pd.Series:
diff --git a/services/strategy-engine/strategies/macd_strategy.py b/services/strategy-engine/strategies/macd_strategy.py
index 356a42b..b5aea07 100644
--- a/services/strategy-engine/strategies/macd_strategy.py
+++ b/services/strategy-engine/strategies/macd_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/moc_strategy.py b/services/strategy-engine/strategies/moc_strategy.py
index 7eaa59e..cbc8440 100644
--- a/services/strategy-engine/strategies/moc_strategy.py
+++ b/services/strategy-engine/strategies/moc_strategy.py
@@ -8,12 +8,12 @@ Rules:
"""
from collections import deque
-from decimal import Decimal
from datetime import datetime
+from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/rsi_strategy.py b/services/strategy-engine/strategies/rsi_strategy.py
index 0646d8c..2df080d 100644
--- a/services/strategy-engine/strategies/rsi_strategy.py
+++ b/services/strategy-engine/strategies/rsi_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import pandas as pd
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
diff --git a/services/strategy-engine/strategies/volume_profile_strategy.py b/services/strategy-engine/strategies/volume_profile_strategy.py
index ef2ae14..67b5c23 100644
--- a/services/strategy-engine/strategies/volume_profile_strategy.py
+++ b/services/strategy-engine/strategies/volume_profile_strategy.py
@@ -3,7 +3,7 @@ from decimal import Decimal
import numpy as np
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
@@ -137,7 +137,7 @@ class VolumeProfileStrategy(BaseStrategy):
if result is None:
return None
- poc, va_low, va_high, hvn_levels, lvn_levels = result
+ poc, va_low, va_high, hvn_levels, _lvn_levels = result
if close < va_low:
self._was_below_va = True
diff --git a/services/strategy-engine/strategies/vwap_strategy.py b/services/strategy-engine/strategies/vwap_strategy.py
index d64950e..4ee4952 100644
--- a/services/strategy-engine/strategies/vwap_strategy.py
+++ b/services/strategy-engine/strategies/vwap_strategy.py
@@ -1,7 +1,7 @@
from collections import deque
from decimal import Decimal
-from shared.models import Candle, Signal, OrderSide
+from shared.models import Candle, OrderSide, Signal
from strategies.base import BaseStrategy
@@ -107,7 +107,7 @@ class VwapStrategy(BaseStrategy):
# Standard deviation of (TP - VWAP) for bands
std_dev = 0.0
if len(self._tp_values) >= 2:
- diffs = [tp - v for tp, v in zip(self._tp_values, self._vwap_values)]
+ diffs = [tp - v for tp, v in zip(self._tp_values, self._vwap_values, strict=True)]
mean_diff = sum(diffs) / len(diffs)
variance = sum((d - mean_diff) ** 2 for d in diffs) / len(diffs)
std_dev = variance**0.5
diff --git a/services/strategy-engine/tests/conftest.py b/services/strategy-engine/tests/conftest.py
index eb31b23..2b909ef 100644
--- a/services/strategy-engine/tests/conftest.py
+++ b/services/strategy-engine/tests/conftest.py
@@ -7,3 +7,8 @@ from pathlib import Path
STRATEGIES_DIR = Path(__file__).parent.parent / "strategies"
if str(STRATEGIES_DIR) not in sys.path:
sys.path.insert(0, str(STRATEGIES_DIR.parent))
+
+# Ensure the worktree's strategy_engine src is preferred over any installed version
+WORKTREE_SRC = Path(__file__).parent.parent / "src"
+if str(WORKTREE_SRC) not in sys.path:
+ sys.path.insert(0, str(WORKTREE_SRC))
diff --git a/services/strategy-engine/tests/test_base_filters.py b/services/strategy-engine/tests/test_base_filters.py
index ae9ca05..66adec7 100644
--- a/services/strategy-engine/tests/test_base_filters.py
+++ b/services/strategy-engine/tests/test_base_filters.py
@@ -5,12 +5,13 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
-from shared.models import Candle, Signal, OrderSide
from strategies.base import BaseStrategy
+from shared.models import Candle, OrderSide, Signal
+
class DummyStrategy(BaseStrategy):
name = "dummy"
@@ -45,7 +46,7 @@ def _candle(price=100.0, volume=10.0, high=None, low=None):
return Candle(
symbol="AAPL",
timeframe="1h",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal(str(price)),
high=Decimal(str(h)),
low=Decimal(str(lo)),
diff --git a/services/strategy-engine/tests/test_bollinger_strategy.py b/services/strategy-engine/tests/test_bollinger_strategy.py
index 8261377..70ec66e 100644
--- a/services/strategy-engine/tests/test_bollinger_strategy.py
+++ b/services/strategy-engine/tests/test_bollinger_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the Bollinger Bands strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.bollinger_strategy import BollingerStrategy
from shared.models import Candle, OrderSide
-from strategies.bollinger_strategy import BollingerStrategy
def make_candle(close: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_combined_strategy.py b/services/strategy-engine/tests/test_combined_strategy.py
index 8a4dc74..6a15250 100644
--- a/services/strategy-engine/tests/test_combined_strategy.py
+++ b/services/strategy-engine/tests/test_combined_strategy.py
@@ -5,13 +5,14 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+from datetime import UTC, datetime
from decimal import Decimal
-from datetime import datetime, timezone
-import pytest
-from shared.models import Candle, Signal, OrderSide
-from strategies.combined_strategy import CombinedStrategy
+import pytest
from strategies.base import BaseStrategy
+from strategies.combined_strategy import CombinedStrategy
+
+from shared.models import Candle, OrderSide, Signal
class AlwaysBuyStrategy(BaseStrategy):
@@ -74,7 +75,7 @@ def _candle(price=100.0):
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal(str(price)),
high=Decimal(str(price + 10)),
low=Decimal(str(price - 10)),
diff --git a/services/strategy-engine/tests/test_ema_crossover_strategy.py b/services/strategy-engine/tests/test_ema_crossover_strategy.py
index 7028eb0..af2b587 100644
--- a/services/strategy-engine/tests/test_ema_crossover_strategy.py
+++ b/services/strategy-engine/tests/test_ema_crossover_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the EMA Crossover strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.ema_crossover_strategy import EmaCrossoverStrategy
from shared.models import Candle, OrderSide
-from strategies.ema_crossover_strategy import EmaCrossoverStrategy
def make_candle(close: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_engine.py b/services/strategy-engine/tests/test_engine.py
index 2623027..fa888b5 100644
--- a/services/strategy-engine/tests/test_engine.py
+++ b/services/strategy-engine/tests/test_engine.py
@@ -1,21 +1,21 @@
"""Tests for the StrategyEngine."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock
import pytest
+from strategy_engine.engine import StrategyEngine
-from shared.models import Candle, Signal, OrderSide
from shared.events import CandleEvent
-from strategy_engine.engine import StrategyEngine
+from shared.models import Candle, OrderSide, Signal
def make_candle_event() -> dict:
candle = Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("50100"),
low=Decimal("49900"),
diff --git a/services/strategy-engine/tests/test_grid_strategy.py b/services/strategy-engine/tests/test_grid_strategy.py
index 878b900..f697012 100644
--- a/services/strategy-engine/tests/test_grid_strategy.py
+++ b/services/strategy-engine/tests/test_grid_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the Grid strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.grid_strategy import GridStrategy
from shared.models import Candle, OrderSide
-from strategies.grid_strategy import GridStrategy
def make_candle(close: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_indicators.py b/services/strategy-engine/tests/test_indicators.py
index 481569b..3147fc4 100644
--- a/services/strategy-engine/tests/test_indicators.py
+++ b/services/strategy-engine/tests/test_indicators.py
@@ -5,14 +5,13 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
-import pandas as pd
import numpy as np
+import pandas as pd
import pytest
-
-from strategies.indicators.trend import sma, ema, macd, adx
-from strategies.indicators.volatility import atr, bollinger_bands
from strategies.indicators.momentum import rsi, stochastic
-from strategies.indicators.volume import volume_sma, volume_ratio, obv
+from strategies.indicators.trend import adx, ema, macd, sma
+from strategies.indicators.volatility import atr, bollinger_bands
+from strategies.indicators.volume import obv, volume_ratio, volume_sma
class TestTrend:
diff --git a/services/strategy-engine/tests/test_macd_strategy.py b/services/strategy-engine/tests/test_macd_strategy.py
index 556fd4c..7fac16f 100644
--- a/services/strategy-engine/tests/test_macd_strategy.py
+++ b/services/strategy-engine/tests/test_macd_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the MACD strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.macd_strategy import MacdStrategy
from shared.models import Candle, OrderSide
-from strategies.macd_strategy import MacdStrategy
def _candle(price: float) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(price)),
high=Decimal(str(price)),
low=Decimal(str(price)),
diff --git a/services/strategy-engine/tests/test_moc_strategy.py b/services/strategy-engine/tests/test_moc_strategy.py
index 1928a28..076e846 100644
--- a/services/strategy-engine/tests/test_moc_strategy.py
+++ b/services/strategy-engine/tests/test_moc_strategy.py
@@ -5,19 +5,20 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
-from shared.models import Candle, OrderSide
from strategies.moc_strategy import MocStrategy
+from shared.models import Candle, OrderSide
+
def _candle(price, hour=20, minute=0, volume=100.0, day=1, open_price=None):
op = open_price if open_price is not None else price - 1 # Default: bullish
return Candle(
symbol="AAPL",
timeframe="5Min",
- open_time=datetime(2025, 1, day, hour, minute, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, day, hour, minute, tzinfo=UTC),
open=Decimal(str(op)),
high=Decimal(str(price + 1)),
low=Decimal(str(min(op, price) - 1)),
diff --git a/services/strategy-engine/tests/test_multi_symbol.py b/services/strategy-engine/tests/test_multi_symbol.py
index 671a9d3..922bfc2 100644
--- a/services/strategy-engine/tests/test_multi_symbol.py
+++ b/services/strategy-engine/tests/test_multi_symbol.py
@@ -9,11 +9,13 @@ import pytest
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+from datetime import UTC, datetime
+from decimal import Decimal
+
from strategy_engine.engine import StrategyEngine
+
from shared.events import CandleEvent
from shared.models import Candle
-from decimal import Decimal
-from datetime import datetime, timezone
@pytest.mark.asyncio
@@ -24,7 +26,7 @@ async def test_engine_processes_multiple_streams():
candle_btc = Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal("50000"),
high=Decimal("51000"),
low=Decimal("49000"),
@@ -34,7 +36,7 @@ async def test_engine_processes_multiple_streams():
candle_eth = Candle(
symbol="MSFT",
timeframe="1m",
- open_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2025, 1, 1, tzinfo=UTC),
open=Decimal("3000"),
high=Decimal("3100"),
low=Decimal("2900"),
diff --git a/services/strategy-engine/tests/test_plugin_loader.py b/services/strategy-engine/tests/test_plugin_loader.py
index 5191fc3..7bd450f 100644
--- a/services/strategy-engine/tests/test_plugin_loader.py
+++ b/services/strategy-engine/tests/test_plugin_loader.py
@@ -2,10 +2,8 @@
from pathlib import Path
-
from strategy_engine.plugin_loader import load_strategies
-
STRATEGIES_DIR = Path(__file__).parent.parent / "strategies"
diff --git a/services/strategy-engine/tests/test_rsi_strategy.py b/services/strategy-engine/tests/test_rsi_strategy.py
index 6d31fd5..6c74f0b 100644
--- a/services/strategy-engine/tests/test_rsi_strategy.py
+++ b/services/strategy-engine/tests/test_rsi_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the RSI strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.rsi_strategy import RsiStrategy
from shared.models import Candle, OrderSide
-from strategies.rsi_strategy import RsiStrategy
def make_candle(close: float, idx: int = 0) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
diff --git a/services/strategy-engine/tests/test_stock_selector.py b/services/strategy-engine/tests/test_stock_selector.py
new file mode 100644
index 0000000..76b8541
--- /dev/null
+++ b/services/strategy-engine/tests/test_stock_selector.py
@@ -0,0 +1,111 @@
+"""Tests for stock selector engine."""
+
+from datetime import UTC, datetime
+from unittest.mock import AsyncMock, MagicMock
+
+from strategy_engine.stock_selector import (
+ SentimentCandidateSource,
+ StockSelector,
+ _extract_json_array,
+ _parse_llm_selections,
+)
+
+
+async def test_sentiment_candidate_source():
+ mock_db = MagicMock()
+ mock_db.get_top_symbol_scores = AsyncMock(
+ return_value=[
+ {"symbol": "AAPL", "composite": 0.8, "news_count": 5},
+ {"symbol": "NVDA", "composite": 0.6, "news_count": 3},
+ ]
+ )
+
+ source = SentimentCandidateSource(mock_db)
+ candidates = await source.get_candidates()
+
+ assert len(candidates) == 2
+ assert candidates[0].symbol == "AAPL"
+ assert candidates[0].source == "sentiment"
+
+
+def test_parse_llm_selections_valid():
+ llm_response = """
+ [
+ {"symbol": "NVDA", "side": "BUY", "conviction": 0.85, "reason": "AI demand", "key_news": ["NVDA beats earnings"]},
+ {"symbol": "XOM", "side": "BUY", "conviction": 0.72, "reason": "Oil surge", "key_news": ["Oil prices up"]}
+ ]
+ """
+ selections = _parse_llm_selections(llm_response)
+ assert len(selections) == 2
+ assert selections[0].symbol == "NVDA"
+ assert selections[0].conviction == 0.85
+
+
+def test_parse_llm_selections_invalid():
+ selections = _parse_llm_selections("not json")
+ assert selections == []
+
+
+def test_parse_llm_selections_with_markdown():
+ llm_response = """
+ Here are my picks:
+ ```json
+ [
+ {"symbol": "TSLA", "side": "BUY", "conviction": 0.7, "reason": "Momentum", "key_news": ["Tesla rally"]}
+ ]
+ ```
+ """
+ selections = _parse_llm_selections(llm_response)
+ assert len(selections) == 1
+ assert selections[0].symbol == "TSLA"
+
+
+def test_extract_json_array_from_markdown():
+ text = '```json\n[{"symbol": "AAPL", "score": 0.9}]\n```'
+ result = _extract_json_array(text)
+ assert result == [{"symbol": "AAPL", "score": 0.9}]
+
+
+def test_extract_json_array_bare():
+ text = '[{"symbol": "TSLA"}]'
+ result = _extract_json_array(text)
+ assert result == [{"symbol": "TSLA"}]
+
+
+def test_extract_json_array_invalid():
+ assert _extract_json_array("not json") is None
+
+
+def test_extract_json_array_filters_non_dicts():
+ text = '[{"symbol": "AAPL"}, "bad", 42]'
+ result = _extract_json_array(text)
+ assert result == [{"symbol": "AAPL"}]
+
+
+async def test_selector_close():
+ selector = StockSelector(
+ db=MagicMock(), broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test"
+ )
+ # No session yet - close should be safe
+ await selector.close()
+ assert selector._http_session is None
+
+
+async def test_selector_blocks_on_risk_off():
+ mock_db = MagicMock()
+ mock_db.get_latest_market_sentiment = AsyncMock(
+ return_value={
+ "fear_greed": 15,
+ "fear_greed_label": "Extreme Fear",
+ "vix": 35.0,
+ "fed_stance": "neutral",
+ "market_regime": "risk_off",
+ "updated_at": datetime.now(UTC),
+ }
+ )
+
+ selector = StockSelector(
+ db=mock_db, broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test"
+ )
+ result = await selector.select()
+ assert result == []
diff --git a/services/strategy-engine/tests/test_strategy_validation.py b/services/strategy-engine/tests/test_strategy_validation.py
index debab1f..0d9607a 100644
--- a/services/strategy-engine/tests/test_strategy_validation.py
+++ b/services/strategy-engine/tests/test_strategy_validation.py
@@ -1,13 +1,11 @@
import pytest
-
-from strategies.rsi_strategy import RsiStrategy
-from strategies.macd_strategy import MacdStrategy
from strategies.bollinger_strategy import BollingerStrategy
from strategies.ema_crossover_strategy import EmaCrossoverStrategy
from strategies.grid_strategy import GridStrategy
-from strategies.vwap_strategy import VwapStrategy
+from strategies.macd_strategy import MacdStrategy
+from strategies.rsi_strategy import RsiStrategy
from strategies.volume_profile_strategy import VolumeProfileStrategy
-
+from strategies.vwap_strategy import VwapStrategy
# ── RSI ──────────────────────────────────────────────────────────────────
diff --git a/services/strategy-engine/tests/test_volume_profile_strategy.py b/services/strategy-engine/tests/test_volume_profile_strategy.py
index 65ee2e8..f47898c 100644
--- a/services/strategy-engine/tests/test_volume_profile_strategy.py
+++ b/services/strategy-engine/tests/test_volume_profile_strategy.py
@@ -1,18 +1,18 @@
"""Tests for the Volume Profile strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.volume_profile_strategy import VolumeProfileStrategy
from shared.models import Candle, OrderSide
-from strategies.volume_profile_strategy import VolumeProfileStrategy
def make_candle(close: float, volume: float = 1.0) -> Candle:
return Candle(
symbol="AAPL",
timeframe="1m",
- open_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+ open_time=datetime(2024, 1, 1, tzinfo=UTC),
open=Decimal(str(close)),
high=Decimal(str(close)),
low=Decimal(str(close)),
@@ -134,13 +134,10 @@ def test_volume_profile_hvn_detection():
# Create a profile with very high volume at price ~100 and low volume elsewhere
# Prices range from 90 to 110, heavy volume concentrated at 100
- candles_data = []
# Low volume at extremes
- for p in [90, 91, 92, 109, 110]:
- candles_data.append((p, 1.0))
+ candles_data = [(p, 1.0) for p in [90, 91, 92, 109, 110]]
# Very high volume around 100
- for _ in range(15):
- candles_data.append((100, 100.0))
+ candles_data.extend((100, 100.0) for _ in range(15))
for price, vol in candles_data:
strategy.on_candle(make_candle(price, vol))
@@ -148,7 +145,7 @@ def test_volume_profile_hvn_detection():
# Access the internal method to verify HVN detection
result = strategy._compute_value_area()
assert result is not None
- poc, va_low, va_high, hvn_levels, lvn_levels = result
+ _poc, _va_low, _va_high, hvn_levels, _lvn_levels = result
# The bin containing price ~100 should have very high volume -> HVN
assert len(hvn_levels) > 0
diff --git a/services/strategy-engine/tests/test_vwap_strategy.py b/services/strategy-engine/tests/test_vwap_strategy.py
index 2c34b01..078d0cf 100644
--- a/services/strategy-engine/tests/test_vwap_strategy.py
+++ b/services/strategy-engine/tests/test_vwap_strategy.py
@@ -1,11 +1,11 @@
"""Tests for the VWAP strategy."""
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal
+from strategies.vwap_strategy import VwapStrategy
from shared.models import Candle, OrderSide
-from strategies.vwap_strategy import VwapStrategy
def make_candle(
@@ -20,7 +20,7 @@ def make_candle(
if low is None:
low = close
if open_time is None:
- open_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
+ open_time = datetime(2024, 1, 1, tzinfo=UTC)
return Candle(
symbol="AAPL",
timeframe="1m",
@@ -111,11 +111,11 @@ def test_vwap_daily_reset():
"""Candles from two different dates cause VWAP to reset."""
strategy = _configured_strategy()
- day1 = datetime(2024, 1, 1, tzinfo=timezone.utc)
- day2 = datetime(2024, 1, 2, tzinfo=timezone.utc)
+ day1 = datetime(2024, 1, 1, tzinfo=UTC)
+ day2 = datetime(2024, 1, 2, tzinfo=UTC)
# Feed 35 candles on day 1 to build VWAP state
- for i in range(35):
+ for _i in range(35):
strategy.on_candle(make_candle(100.0, high=101.0, low=99.0, open_time=day1))
# Verify state is built up