diff options
Diffstat (limited to 'services/strategy-engine')
39 files changed, 713 insertions, 107 deletions
diff --git a/services/strategy-engine/Dockerfile b/services/strategy-engine/Dockerfile index de635dc..f1484e9 100644 --- a/services/strategy-engine/Dockerfile +++ b/services/strategy-engine/Dockerfile @@ -1,9 +1,16 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/strategy-engine/ services/strategy-engine/ RUN pip install --no-cache-dir ./services/strategy-engine + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin COPY services/strategy-engine/strategies/ /app/strategies/ ENV PYTHONPATH=/app +USER appuser CMD ["python", "-m", "strategy_engine.main"] diff --git a/services/strategy-engine/pyproject.toml b/services/strategy-engine/pyproject.toml index 4f5b6be..e4bfb12 100644 --- a/services/strategy-engine/pyproject.toml +++ b/services/strategy-engine/pyproject.toml @@ -3,11 +3,7 @@ name = "strategy-engine" version = "0.1.0" description = "Plugin-based strategy execution engine" requires-python = ">=3.12" -dependencies = [ - "pandas>=2.0", - "numpy>=1.20", - "trading-shared", -] +dependencies = ["pandas>=2.1,<3", "numpy>=1.26,<3", "trading-shared"] [project.optional-dependencies] dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] diff --git a/services/strategy-engine/src/strategy_engine/config.py b/services/strategy-engine/src/strategy_engine/config.py index e3a49c2..9fd9c49 100644 --- a/services/strategy-engine/src/strategy_engine/config.py +++ b/services/strategy-engine/src/strategy_engine/config.py @@ -4,6 +4,6 @@ from shared.config import Settings class StrategyConfig(Settings): - symbols: list[str] = ["BTC/USDT"] + symbols: list[str] = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"] timeframes: list[str] = ["1m"] strategy_params: dict = {} diff --git a/services/strategy-engine/src/strategy_engine/engine.py b/services/strategy-engine/src/strategy_engine/engine.py index d401aee..4b2c468 100644 --- a/services/strategy-engine/src/strategy_engine/engine.py +++ b/services/strategy-engine/src/strategy_engine/engine.py @@ -2,11 +2,11 @@ import logging -from shared.broker import RedisBroker -from shared.events import CandleEvent, SignalEvent, Event - from strategies.base import BaseStrategy +from shared.broker import RedisBroker +from shared.events import CandleEvent, Event, SignalEvent + logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ class StrategyEngine: try: event = Event.from_dict(raw) except Exception as exc: - logger.warning("Failed to parse event: %s – %s", raw, exc) + logger.warning("Failed to parse event: %s - %s", raw, exc) continue if not isinstance(event, CandleEvent): diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py index 30de528..3d73058 100644 --- a/services/strategy-engine/src/strategy_engine/main.py +++ b/services/strategy-engine/src/strategy_engine/main.py @@ -1,17 +1,25 @@ """Strategy Engine Service entry point.""" import asyncio +import zoneinfo +from datetime import datetime from pathlib import Path +import aiohttp + +from shared.alpaca import AlpacaClient from shared.broker import RedisBroker +from shared.db import Database from shared.healthcheck import HealthCheckServer from shared.logging import setup_logging from shared.metrics import ServiceMetrics from shared.notifier import TelegramNotifier - +from shared.sentiment_models import MarketSentiment +from shared.shutdown import GracefulShutdown from strategy_engine.config import StrategyConfig from strategy_engine.engine import StrategyEngine from strategy_engine.plugin_loader import load_strategies +from strategy_engine.stock_selector import StockSelector # The strategies directory lives alongside the installed package STRATEGIES_DIR = Path(__file__).parent.parent.parent.parent / "strategies" @@ -30,23 +38,74 @@ async def process_symbol(engine: StrategyEngine, stream: str, log) -> None: last_id = await engine.process_once(stream, last_id) +async def run_stock_selector( + selector: StockSelector, + notifier: TelegramNotifier, + db: Database, + config: StrategyConfig, + log, +) -> None: + """Run the stock selector once per day at the configured time.""" + et = zoneinfo.ZoneInfo("America/New_York") + + while True: + now_et = datetime.now(et) + target_hour, target_min = map(int, config.selector_final_time.split(":")) + + if now_et.hour == target_hour and now_et.minute == target_min: + log.info("stock_selector_running") + try: + selections = await selector.select() + if selections: + ms_data = await db.get_latest_market_sentiment() + ms = None + if ms_data: + ms = MarketSentiment(**ms_data) + await notifier.send_stock_selection(selections, ms) + log.info("stock_selector_complete", picks=[s.symbol for s in selections]) + else: + log.info("stock_selector_no_picks") + except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc: + log.warning("stock_selector_network_error", error=str(exc)) + except (ValueError, KeyError, TypeError) as exc: + log.warning("stock_selector_data_error", error=str(exc)) + except Exception as exc: + log.error("stock_selector_error", error=str(exc), exc_info=True) + await asyncio.sleep(120) # Sleep past this minute + else: + await asyncio.sleep(30) + + async def run() -> None: config = StrategyConfig() log = setup_logging("strategy-engine", config.log_level, config.log_format) metrics = ServiceMetrics("strategy_engine") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id, ) - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) + + db = Database(config.database_url.get_secret_value()) + await db.connect() + + alpaca = AlpacaClient( + api_key=config.alpaca_api_key.get_secret_value(), + api_secret=config.alpaca_api_secret.get_secret_value(), + paper=config.alpaca_paper, + ) + strategies = load_strategies(STRATEGIES_DIR) for strategy in strategies: params = config.strategy_params.get(strategy.name, {}) strategy.configure(params) + shutdown = GracefulShutdown() + shutdown.install_handlers() + log.info("loaded_strategies", count=len(strategies), names=[s.name for s in strategies]) engine = StrategyEngine(broker=broker, strategies=strategies) @@ -67,9 +126,23 @@ async def run() -> None: task = asyncio.create_task(process_symbol(engine, stream, log)) tasks.append(task) - await asyncio.gather(*tasks) + if config.anthropic_api_key.get_secret_value(): + selector = StockSelector( + db=db, + broker=broker, + alpaca=alpaca, + anthropic_api_key=config.anthropic_api_key.get_secret_value(), + anthropic_model=config.anthropic_model, + max_picks=config.selector_max_picks, + ) + tasks.append( + asyncio.create_task(run_stock_selector(selector, notifier, db, config, log)) + ) + log.info("stock_selector_enabled", time=config.selector_final_time) + + await shutdown.wait() except Exception as exc: - log.error("fatal_error", error=str(exc)) + log.error("fatal_error", error=str(exc), exc_info=True) await notifier.send_error(str(exc), "strategy-engine") raise finally: @@ -78,6 +151,8 @@ async def run() -> None: metrics.service_up.labels(service="strategy-engine").set(0) await notifier.close() await broker.close() + await alpaca.close() + await db.close() def main() -> None: diff --git a/services/strategy-engine/src/strategy_engine/plugin_loader.py b/services/strategy-engine/src/strategy_engine/plugin_loader.py index 62e4160..57680db 100644 --- a/services/strategy-engine/src/strategy_engine/plugin_loader.py +++ b/services/strategy-engine/src/strategy_engine/plugin_loader.py @@ -5,7 +5,6 @@ import sys from pathlib import Path import yaml - from strategies.base import BaseStrategy diff --git a/services/strategy-engine/src/strategy_engine/stock_selector.py b/services/strategy-engine/src/strategy_engine/stock_selector.py new file mode 100644 index 0000000..8657b93 --- /dev/null +++ b/services/strategy-engine/src/strategy_engine/stock_selector.py @@ -0,0 +1,418 @@ +"""3-stage stock selector engine: sentiment → technical → LLM.""" + +import asyncio +import json +import logging +import re +from datetime import UTC, datetime + +import aiohttp + +from shared.alpaca import AlpacaClient +from shared.broker import RedisBroker +from shared.db import Database +from shared.models import OrderSide +from shared.sentiment_models import Candidate, MarketSentiment, SelectedStock + +logger = logging.getLogger(__name__) + +ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages" + + +def _extract_json_array(text: str) -> list[dict] | None: + """Extract a JSON array from text that may contain markdown code blocks.""" + code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) + if code_block: + raw = code_block.group(1) + else: + array_match = re.search(r"\[.*\]", text, re.DOTALL) + if array_match: + raw = array_match.group(0) + else: + raw = text.strip() + + try: + data = json.loads(raw) + if isinstance(data, list): + return [item for item in data if isinstance(item, dict)] + return None + except (json.JSONDecodeError, TypeError): + return None + + +def _parse_llm_selections(text: str) -> list[SelectedStock]: + """Parse LLM response into SelectedStock list. + + Handles both bare JSON arrays and markdown code blocks. + Returns empty list on any parse error. + """ + items = _extract_json_array(text) + if items is None: + return [] + + selections = [] + for item in items: + try: + selection = SelectedStock( + symbol=item["symbol"], + side=OrderSide(item["side"]), + conviction=float(item["conviction"]), + reason=item.get("reason", ""), + key_news=item.get("key_news", []), + ) + selections.append(selection) + except (KeyError, ValueError) as e: + logger.warning("Skipping invalid selection item: %s", e) + return selections + + +class SentimentCandidateSource: + """Generates candidates from DB sentiment scores.""" + + def __init__(self, db: Database) -> None: + self._db = db + + async def get_candidates(self) -> list[Candidate]: + rows = await self._db.get_top_symbol_scores(limit=20) + candidates = [] + for row in rows: + composite = float(row.get("composite", 0)) + if composite == 0: + continue + candidates.append( + Candidate( + symbol=row["symbol"], + source="sentiment", + score=composite, + reason=f"composite={composite:.2f}, news_count={row.get('news_count', 0)}", + ) + ) + return candidates + + +class LLMCandidateSource: + """Generates candidates by asking Claude to analyze recent news.""" + + def __init__(self, db: Database, api_key: str, model: str) -> None: + self._db = db + self._api_key = api_key + self._model = model + + async def get_candidates(self, session: aiohttp.ClientSession | None = None) -> list[Candidate]: + news_items = await self._db.get_recent_news(hours=24) + if not news_items: + return [] + + headlines = [] + for item in news_items[:50]: # cap at 50 to stay within context + symbols = item.get("symbols", []) + sym_str = ", ".join(symbols) if symbols else "N/A" + headlines.append(f"[{sym_str}] {item['headline']}") + + prompt = ( + "You are a stock analyst. Given recent news headlines, identify the 5-10 most " + "actionable US stock tickers. Return ONLY a JSON array with objects having: " + "symbol (ticker), direction ('BUY' or 'SELL'), score (0-1), reason (brief).\n\n" + "Headlines:\n" + "\n".join(headlines) + ) + + own_session = session is None + if own_session: + session = aiohttp.ClientSession() + + try: + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM candidate source error %d: %s", resp.status, body) + return [] + data = await resp.json() + + content = data.get("content", []) + text = "" + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text += block.get("text", "") + + return self._parse_candidates(text) + except Exception as e: + logger.error("LLMCandidateSource error: %s", e) + return [] + finally: + if own_session: + await session.close() + + def _parse_candidates(self, text: str) -> list[Candidate]: + items = _extract_json_array(text) + if items is None: + return [] + + candidates = [] + for item in items: + try: + direction_str = item.get("direction", "BUY") + direction = OrderSide(direction_str) + except ValueError: + direction = None + candidates.append( + Candidate( + symbol=item["symbol"], + source="llm", + direction=direction, + score=float(item.get("score", 0.5)), + reason=item.get("reason", ""), + ) + ) + return candidates + + +def _compute_rsi(closes: list[float], period: int = 14) -> float: + """Compute RSI for the last data point.""" + if len(closes) < period + 1: + return 50.0 # neutral if insufficient data + + deltas = [closes[i] - closes[i - 1] for i in range(1, len(closes))] + gains = [d if d > 0 else 0.0 for d in deltas] + losses = [-d if d < 0 else 0.0 for d in deltas] + + avg_gain = sum(gains[:period]) / period + avg_loss = sum(losses[:period]) / period + + for i in range(period, len(deltas)): + avg_gain = (avg_gain * (period - 1) + gains[i]) / period + avg_loss = (avg_loss * (period - 1) + losses[i]) / period + + if avg_loss == 0: + return 100.0 + rs = avg_gain / avg_loss + return 100.0 - (100.0 / (1.0 + rs)) + + +class StockSelector: + """Orchestrates the 3-stage stock selection pipeline.""" + + def __init__( + self, + db: Database, + broker: RedisBroker, + alpaca: AlpacaClient, + anthropic_api_key: str, + anthropic_model: str = "claude-sonnet-4-20250514", + max_picks: int = 3, + ) -> None: + self._db = db + self._broker = broker + self._alpaca = alpaca + self._api_key = anthropic_api_key + self._model = anthropic_model + self._max_picks = max_picks + self._http_session: aiohttp.ClientSession | None = None + self._session_lock = asyncio.Lock() + + async def _ensure_session(self) -> aiohttp.ClientSession: + async with self._session_lock: + if self._http_session is None or self._http_session.closed: + self._http_session = aiohttp.ClientSession() + return self._http_session + + async def close(self) -> None: + if self._http_session and not self._http_session.closed: + await self._http_session.close() + + async def select(self) -> list[SelectedStock]: + """Run the full 3-stage pipeline and return selected stocks.""" + # Market gate: check sentiment + sentiment_data = await self._db.get_latest_market_sentiment() + if sentiment_data is None: + logger.warning("No market sentiment data; skipping selection") + return [] + + market_sentiment = MarketSentiment(**sentiment_data) + if market_sentiment.market_regime == "risk_off": + logger.info("Market is risk_off; skipping stock selection") + return [] + + # Stage 1: gather candidates from both sources + sentiment_source = SentimentCandidateSource(self._db) + llm_source = LLMCandidateSource(self._db, self._api_key, self._model) + + session = await self._ensure_session() + sentiment_candidates = await sentiment_source.get_candidates() + llm_candidates = await llm_source.get_candidates(session=session) + + candidates = self._merge_candidates(sentiment_candidates, llm_candidates) + if not candidates: + logger.info("No candidates found") + return [] + + # Stage 2: technical filter + filtered = await self._technical_filter(candidates) + if not filtered: + logger.info("All candidates filtered out by technical criteria") + return [] + + # Stage 3: LLM final selection + selections = await self._llm_final_select(filtered, market_sentiment) + + # Persist and publish + today = datetime.now(UTC).date() + sentiment_snapshot = { + "fear_greed": market_sentiment.fear_greed, + "market_regime": market_sentiment.market_regime, + "vix": market_sentiment.vix, + } + for stock in selections: + try: + await self._db.insert_stock_selection( + trade_date=today, + symbol=stock.symbol, + side=stock.side.value, + conviction=stock.conviction, + reason=stock.reason, + key_news=stock.key_news, + sentiment_snapshot=sentiment_snapshot, + ) + except Exception as e: + logger.error("Failed to persist selection for %s: %s", stock.symbol, e) + + try: + await self._broker.publish( + "selected_stocks", + { + "symbol": stock.symbol, + "side": stock.side.value, + "conviction": stock.conviction, + "reason": stock.reason, + "key_news": stock.key_news, + "trade_date": str(today), + }, + ) + except Exception as e: + logger.error("Failed to publish selection for %s: %s", stock.symbol, e) + + return selections + + def _merge_candidates( + self, sentiment: list[Candidate], llm: list[Candidate] + ) -> list[Candidate]: + """Deduplicate candidates by symbol, keeping the higher score.""" + by_symbol: dict[str, Candidate] = {} + for c in sentiment + llm: + existing = by_symbol.get(c.symbol) + if existing is None or c.score > existing.score: + by_symbol[c.symbol] = c + return sorted(by_symbol.values(), key=lambda c: c.score, reverse=True) + + async def _technical_filter(self, candidates: list[Candidate]) -> list[Candidate]: + """Filter candidates using RSI, EMA20, and volume criteria.""" + passed = [] + for candidate in candidates: + try: + bars = await self._alpaca.get_bars(candidate.symbol, timeframe="1Day", limit=60) + if len(bars) < 21: + logger.debug("Insufficient bars for %s", candidate.symbol) + continue + + closes = [float(b["c"]) for b in bars] + volumes = [float(b["v"]) for b in bars] + + rsi = _compute_rsi(closes) + if not (30 <= rsi <= 70): + logger.debug("%s RSI=%.1f outside 30-70", candidate.symbol, rsi) + continue + + ema20 = sum(closes[-20:]) / 20 # simple approximation + current_price = closes[-1] + if current_price <= ema20: + logger.debug( + "%s price %.2f <= EMA20 %.2f", candidate.symbol, current_price, ema20 + ) + continue + + avg_volume = sum(volumes[:-1]) / max(len(volumes) - 1, 1) + current_volume = volumes[-1] + if current_volume <= 0.5 * avg_volume: + logger.debug( + "%s volume %.0f <= 50%% avg %.0f", + candidate.symbol, + current_volume, + avg_volume, + ) + continue + + passed.append(candidate) + except Exception as e: + logger.warning("Technical filter error for %s: %s", candidate.symbol, e) + + return passed + + async def _llm_final_select( + self, candidates: list[Candidate], market_sentiment: MarketSentiment + ) -> list[SelectedStock]: + """Ask Claude to pick 2-3 stocks with rationale.""" + candidate_lines = [ + f"- {c.symbol} (source={c.source}, score={c.score:.2f}, reason={c.reason})" + for c in candidates + ] + market_context = ( + f"Fear/Greed: {market_sentiment.fear_greed} ({market_sentiment.fear_greed_label}), " + f"VIX: {market_sentiment.vix}, " + f"Fed stance: {market_sentiment.fed_stance}, " + f"Regime: {market_sentiment.market_regime}" + ) + + prompt = ( + f"You are a portfolio manager. Select 2-3 stocks for today's session.\n\n" + f"Market context: {market_context}\n\n" + f"Candidates (already passed technical filters):\n" + + "\n".join(candidate_lines) + + "\n\n" + "Return ONLY a JSON array with objects having:\n" + " symbol, side ('BUY' or 'SELL'), conviction (0-1), reason (1-2 sentences), " + "key_news (list of 1-3 relevant headlines or facts)\n" + f"Select at most {self._max_picks} stocks." + ) + + try: + session = await self._ensure_session() + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM final select error %d: %s", resp.status, body) + return [] + data = await resp.json() + + content = data.get("content", []) + text = "" + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text += block.get("text", "") + + return _parse_llm_selections(text)[: self._max_picks] + except Exception as e: + logger.error("LLM final select error: %s", e) + return [] diff --git a/services/strategy-engine/strategies/base.py b/services/strategy-engine/strategies/base.py index d5be675..1d9d289 100644 --- a/services/strategy-engine/strategies/base.py +++ b/services/strategy-engine/strategies/base.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from collections import deque from decimal import Decimal -from typing import Optional import pandas as pd @@ -102,7 +101,7 @@ class BaseStrategy(ABC): def _calculate_atr_stops( self, entry_price: Decimal, side: str - ) -> tuple[Optional[Decimal], Optional[Decimal]]: + ) -> tuple[Decimal | None, Decimal | None]: """Calculate ATR-based stop-loss and take-profit. Returns (stop_loss, take_profit) as Decimal or (None, None) if not enough data. @@ -131,7 +130,7 @@ class BaseStrategy(ABC): return sl, tp - def _apply_filters(self, signal: Signal) -> Optional[Signal]: + def _apply_filters(self, signal: Signal) -> Signal | None: """Apply all filters to a signal. Returns signal with SL/TP or None if filtered out.""" if signal is None: return None diff --git a/services/strategy-engine/strategies/bollinger_strategy.py b/services/strategy-engine/strategies/bollinger_strategy.py index ebe7967..02ff09a 100644 --- a/services/strategy-engine/strategies/bollinger_strategy.py +++ b/services/strategy-engine/strategies/bollinger_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/combined_strategy.py b/services/strategy-engine/strategies/combined_strategy.py index ba92485..f562918 100644 --- a/services/strategy-engine/strategies/combined_strategy.py +++ b/services/strategy-engine/strategies/combined_strategy.py @@ -2,7 +2,7 @@ from decimal import Decimal -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/ema_crossover_strategy.py b/services/strategy-engine/strategies/ema_crossover_strategy.py index 68d0ba3..9c181f3 100644 --- a/services/strategy-engine/strategies/ema_crossover_strategy.py +++ b/services/strategy-engine/strategies/ema_crossover_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/grid_strategy.py b/services/strategy-engine/strategies/grid_strategy.py index 283bfe5..491252e 100644 --- a/services/strategy-engine/strategies/grid_strategy.py +++ b/services/strategy-engine/strategies/grid_strategy.py @@ -1,9 +1,8 @@ from decimal import Decimal -from typing import Optional import numpy as np -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy @@ -17,7 +16,7 @@ class GridStrategy(BaseStrategy): self._grid_count: int = 5 self._quantity: Decimal = Decimal("0.01") self._grid_levels: list[float] = [] - self._last_zone: Optional[int] = None + self._last_zone: int | None = None self._exit_threshold_pct: float = 5.0 self._out_of_range: bool = False self._in_position: bool = False # Track if we have any grid positions diff --git a/services/strategy-engine/strategies/indicators/__init__.py b/services/strategy-engine/strategies/indicators/__init__.py index 3c713e6..01637b7 100644 --- a/services/strategy-engine/strategies/indicators/__init__.py +++ b/services/strategy-engine/strategies/indicators/__init__.py @@ -1,21 +1,21 @@ """Reusable technical indicator functions.""" -from strategies.indicators.trend import ema, sma, macd, adx -from strategies.indicators.volatility import atr, bollinger_bands, keltner_channels from strategies.indicators.momentum import rsi, stochastic -from strategies.indicators.volume import volume_sma, volume_ratio, obv +from strategies.indicators.trend import adx, ema, macd, sma +from strategies.indicators.volatility import atr, bollinger_bands, keltner_channels +from strategies.indicators.volume import obv, volume_ratio, volume_sma __all__ = [ - "ema", - "sma", - "macd", "adx", "atr", "bollinger_bands", + "ema", "keltner_channels", + "macd", + "obv", "rsi", + "sma", "stochastic", - "volume_sma", "volume_ratio", - "obv", + "volume_sma", ] diff --git a/services/strategy-engine/strategies/indicators/momentum.py b/services/strategy-engine/strategies/indicators/momentum.py index c479452..a82210b 100644 --- a/services/strategy-engine/strategies/indicators/momentum.py +++ b/services/strategy-engine/strategies/indicators/momentum.py @@ -1,7 +1,7 @@ """Momentum indicators: RSI, Stochastic.""" -import pandas as pd import numpy as np +import pandas as pd def rsi(closes: pd.Series, period: int = 14) -> pd.Series: diff --git a/services/strategy-engine/strategies/indicators/trend.py b/services/strategy-engine/strategies/indicators/trend.py index c94a071..1085199 100644 --- a/services/strategy-engine/strategies/indicators/trend.py +++ b/services/strategy-engine/strategies/indicators/trend.py @@ -1,7 +1,7 @@ """Trend indicators: EMA, SMA, MACD, ADX.""" -import pandas as pd import numpy as np +import pandas as pd def sma(series: pd.Series, period: int) -> pd.Series: diff --git a/services/strategy-engine/strategies/indicators/volatility.py b/services/strategy-engine/strategies/indicators/volatility.py index c16143e..da82f26 100644 --- a/services/strategy-engine/strategies/indicators/volatility.py +++ b/services/strategy-engine/strategies/indicators/volatility.py @@ -1,7 +1,7 @@ """Volatility indicators: ATR, Bollinger Bands, Keltner Channels.""" -import pandas as pd import numpy as np +import pandas as pd def atr( diff --git a/services/strategy-engine/strategies/indicators/volume.py b/services/strategy-engine/strategies/indicators/volume.py index 502f1ce..d7c6471 100644 --- a/services/strategy-engine/strategies/indicators/volume.py +++ b/services/strategy-engine/strategies/indicators/volume.py @@ -1,7 +1,7 @@ """Volume indicators: Volume SMA, Volume Ratio, OBV.""" -import pandas as pd import numpy as np +import pandas as pd def volume_sma(volumes: pd.Series, period: int = 20) -> pd.Series: diff --git a/services/strategy-engine/strategies/macd_strategy.py b/services/strategy-engine/strategies/macd_strategy.py index 356a42b..b5aea07 100644 --- a/services/strategy-engine/strategies/macd_strategy.py +++ b/services/strategy-engine/strategies/macd_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/moc_strategy.py b/services/strategy-engine/strategies/moc_strategy.py index 7eaa59e..cbc8440 100644 --- a/services/strategy-engine/strategies/moc_strategy.py +++ b/services/strategy-engine/strategies/moc_strategy.py @@ -8,12 +8,12 @@ Rules: """ from collections import deque -from decimal import Decimal from datetime import datetime +from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/rsi_strategy.py b/services/strategy-engine/strategies/rsi_strategy.py index 0646d8c..2df080d 100644 --- a/services/strategy-engine/strategies/rsi_strategy.py +++ b/services/strategy-engine/strategies/rsi_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/volume_profile_strategy.py b/services/strategy-engine/strategies/volume_profile_strategy.py index ef2ae14..67b5c23 100644 --- a/services/strategy-engine/strategies/volume_profile_strategy.py +++ b/services/strategy-engine/strategies/volume_profile_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import numpy as np -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy @@ -137,7 +137,7 @@ class VolumeProfileStrategy(BaseStrategy): if result is None: return None - poc, va_low, va_high, hvn_levels, lvn_levels = result + poc, va_low, va_high, hvn_levels, _lvn_levels = result if close < va_low: self._was_below_va = True diff --git a/services/strategy-engine/strategies/vwap_strategy.py b/services/strategy-engine/strategies/vwap_strategy.py index d64950e..4ee4952 100644 --- a/services/strategy-engine/strategies/vwap_strategy.py +++ b/services/strategy-engine/strategies/vwap_strategy.py @@ -1,7 +1,7 @@ from collections import deque from decimal import Decimal -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy @@ -107,7 +107,7 @@ class VwapStrategy(BaseStrategy): # Standard deviation of (TP - VWAP) for bands std_dev = 0.0 if len(self._tp_values) >= 2: - diffs = [tp - v for tp, v in zip(self._tp_values, self._vwap_values)] + diffs = [tp - v for tp, v in zip(self._tp_values, self._vwap_values, strict=True)] mean_diff = sum(diffs) / len(diffs) variance = sum((d - mean_diff) ** 2 for d in diffs) / len(diffs) std_dev = variance**0.5 diff --git a/services/strategy-engine/tests/conftest.py b/services/strategy-engine/tests/conftest.py index eb31b23..2b909ef 100644 --- a/services/strategy-engine/tests/conftest.py +++ b/services/strategy-engine/tests/conftest.py @@ -7,3 +7,8 @@ from pathlib import Path STRATEGIES_DIR = Path(__file__).parent.parent / "strategies" if str(STRATEGIES_DIR) not in sys.path: sys.path.insert(0, str(STRATEGIES_DIR.parent)) + +# Ensure the worktree's strategy_engine src is preferred over any installed version +WORKTREE_SRC = Path(__file__).parent.parent / "src" +if str(WORKTREE_SRC) not in sys.path: + sys.path.insert(0, str(WORKTREE_SRC)) diff --git a/services/strategy-engine/tests/test_base_filters.py b/services/strategy-engine/tests/test_base_filters.py index ae9ca05..66adec7 100644 --- a/services/strategy-engine/tests/test_base_filters.py +++ b/services/strategy-engine/tests/test_base_filters.py @@ -5,12 +5,13 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone -from shared.models import Candle, Signal, OrderSide from strategies.base import BaseStrategy +from shared.models import Candle, OrderSide, Signal + class DummyStrategy(BaseStrategy): name = "dummy" @@ -45,7 +46,7 @@ def _candle(price=100.0, volume=10.0, high=None, low=None): return Candle( symbol="AAPL", timeframe="1h", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, tzinfo=UTC), open=Decimal(str(price)), high=Decimal(str(h)), low=Decimal(str(lo)), diff --git a/services/strategy-engine/tests/test_bollinger_strategy.py b/services/strategy-engine/tests/test_bollinger_strategy.py index 8261377..70ec66e 100644 --- a/services/strategy-engine/tests/test_bollinger_strategy.py +++ b/services/strategy-engine/tests/test_bollinger_strategy.py @@ -1,18 +1,18 @@ """Tests for the Bollinger Bands strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.bollinger_strategy import BollingerStrategy from shared.models import Candle, OrderSide -from strategies.bollinger_strategy import BollingerStrategy def make_candle(close: float) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), diff --git a/services/strategy-engine/tests/test_combined_strategy.py b/services/strategy-engine/tests/test_combined_strategy.py index 8a4dc74..6a15250 100644 --- a/services/strategy-engine/tests/test_combined_strategy.py +++ b/services/strategy-engine/tests/test_combined_strategy.py @@ -5,13 +5,14 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone -import pytest -from shared.models import Candle, Signal, OrderSide -from strategies.combined_strategy import CombinedStrategy +import pytest from strategies.base import BaseStrategy +from strategies.combined_strategy import CombinedStrategy + +from shared.models import Candle, OrderSide, Signal class AlwaysBuyStrategy(BaseStrategy): @@ -74,7 +75,7 @@ def _candle(price=100.0): return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, tzinfo=UTC), open=Decimal(str(price)), high=Decimal(str(price + 10)), low=Decimal(str(price - 10)), diff --git a/services/strategy-engine/tests/test_ema_crossover_strategy.py b/services/strategy-engine/tests/test_ema_crossover_strategy.py index 7028eb0..af2b587 100644 --- a/services/strategy-engine/tests/test_ema_crossover_strategy.py +++ b/services/strategy-engine/tests/test_ema_crossover_strategy.py @@ -1,18 +1,18 @@ """Tests for the EMA Crossover strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.ema_crossover_strategy import EmaCrossoverStrategy from shared.models import Candle, OrderSide -from strategies.ema_crossover_strategy import EmaCrossoverStrategy def make_candle(close: float) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), diff --git a/services/strategy-engine/tests/test_engine.py b/services/strategy-engine/tests/test_engine.py index 2623027..fa888b5 100644 --- a/services/strategy-engine/tests/test_engine.py +++ b/services/strategy-engine/tests/test_engine.py @@ -1,21 +1,21 @@ """Tests for the StrategyEngine.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal from unittest.mock import AsyncMock, MagicMock import pytest +from strategy_engine.engine import StrategyEngine -from shared.models import Candle, Signal, OrderSide from shared.events import CandleEvent -from strategy_engine.engine import StrategyEngine +from shared.models import Candle, OrderSide, Signal def make_candle_event() -> dict: candle = Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("50100"), low=Decimal("49900"), diff --git a/services/strategy-engine/tests/test_grid_strategy.py b/services/strategy-engine/tests/test_grid_strategy.py index 878b900..f697012 100644 --- a/services/strategy-engine/tests/test_grid_strategy.py +++ b/services/strategy-engine/tests/test_grid_strategy.py @@ -1,18 +1,18 @@ """Tests for the Grid strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.grid_strategy import GridStrategy from shared.models import Candle, OrderSide -from strategies.grid_strategy import GridStrategy def make_candle(close: float) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), diff --git a/services/strategy-engine/tests/test_indicators.py b/services/strategy-engine/tests/test_indicators.py index 481569b..3147fc4 100644 --- a/services/strategy-engine/tests/test_indicators.py +++ b/services/strategy-engine/tests/test_indicators.py @@ -5,14 +5,13 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) -import pandas as pd import numpy as np +import pandas as pd import pytest - -from strategies.indicators.trend import sma, ema, macd, adx -from strategies.indicators.volatility import atr, bollinger_bands from strategies.indicators.momentum import rsi, stochastic -from strategies.indicators.volume import volume_sma, volume_ratio, obv +from strategies.indicators.trend import adx, ema, macd, sma +from strategies.indicators.volatility import atr, bollinger_bands +from strategies.indicators.volume import obv, volume_ratio, volume_sma class TestTrend: diff --git a/services/strategy-engine/tests/test_macd_strategy.py b/services/strategy-engine/tests/test_macd_strategy.py index 556fd4c..7fac16f 100644 --- a/services/strategy-engine/tests/test_macd_strategy.py +++ b/services/strategy-engine/tests/test_macd_strategy.py @@ -1,18 +1,18 @@ """Tests for the MACD strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.macd_strategy import MacdStrategy from shared.models import Candle, OrderSide -from strategies.macd_strategy import MacdStrategy def _candle(price: float) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(price)), high=Decimal(str(price)), low=Decimal(str(price)), diff --git a/services/strategy-engine/tests/test_moc_strategy.py b/services/strategy-engine/tests/test_moc_strategy.py index 1928a28..076e846 100644 --- a/services/strategy-engine/tests/test_moc_strategy.py +++ b/services/strategy-engine/tests/test_moc_strategy.py @@ -5,19 +5,20 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal -from shared.models import Candle, OrderSide from strategies.moc_strategy import MocStrategy +from shared.models import Candle, OrderSide + def _candle(price, hour=20, minute=0, volume=100.0, day=1, open_price=None): op = open_price if open_price is not None else price - 1 # Default: bullish return Candle( symbol="AAPL", timeframe="5Min", - open_time=datetime(2025, 1, day, hour, minute, tzinfo=timezone.utc), + open_time=datetime(2025, 1, day, hour, minute, tzinfo=UTC), open=Decimal(str(op)), high=Decimal(str(price + 1)), low=Decimal(str(min(op, price) - 1)), diff --git a/services/strategy-engine/tests/test_multi_symbol.py b/services/strategy-engine/tests/test_multi_symbol.py index 671a9d3..922bfc2 100644 --- a/services/strategy-engine/tests/test_multi_symbol.py +++ b/services/strategy-engine/tests/test_multi_symbol.py @@ -9,11 +9,13 @@ import pytest sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +from datetime import UTC, datetime +from decimal import Decimal + from strategy_engine.engine import StrategyEngine + from shared.events import CandleEvent from shared.models import Candle -from decimal import Decimal -from datetime import datetime, timezone @pytest.mark.asyncio @@ -24,7 +26,7 @@ async def test_engine_processes_multiple_streams(): candle_btc = Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("51000"), low=Decimal("49000"), @@ -34,7 +36,7 @@ async def test_engine_processes_multiple_streams(): candle_eth = Candle( symbol="MSFT", timeframe="1m", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, tzinfo=UTC), open=Decimal("3000"), high=Decimal("3100"), low=Decimal("2900"), diff --git a/services/strategy-engine/tests/test_plugin_loader.py b/services/strategy-engine/tests/test_plugin_loader.py index 5191fc3..7bd450f 100644 --- a/services/strategy-engine/tests/test_plugin_loader.py +++ b/services/strategy-engine/tests/test_plugin_loader.py @@ -2,10 +2,8 @@ from pathlib import Path - from strategy_engine.plugin_loader import load_strategies - STRATEGIES_DIR = Path(__file__).parent.parent / "strategies" diff --git a/services/strategy-engine/tests/test_rsi_strategy.py b/services/strategy-engine/tests/test_rsi_strategy.py index 6d31fd5..6c74f0b 100644 --- a/services/strategy-engine/tests/test_rsi_strategy.py +++ b/services/strategy-engine/tests/test_rsi_strategy.py @@ -1,18 +1,18 @@ """Tests for the RSI strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.rsi_strategy import RsiStrategy from shared.models import Candle, OrderSide -from strategies.rsi_strategy import RsiStrategy def make_candle(close: float, idx: int = 0) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), diff --git a/services/strategy-engine/tests/test_stock_selector.py b/services/strategy-engine/tests/test_stock_selector.py new file mode 100644 index 0000000..76b8541 --- /dev/null +++ b/services/strategy-engine/tests/test_stock_selector.py @@ -0,0 +1,111 @@ +"""Tests for stock selector engine.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock + +from strategy_engine.stock_selector import ( + SentimentCandidateSource, + StockSelector, + _extract_json_array, + _parse_llm_selections, +) + + +async def test_sentiment_candidate_source(): + mock_db = MagicMock() + mock_db.get_top_symbol_scores = AsyncMock( + return_value=[ + {"symbol": "AAPL", "composite": 0.8, "news_count": 5}, + {"symbol": "NVDA", "composite": 0.6, "news_count": 3}, + ] + ) + + source = SentimentCandidateSource(mock_db) + candidates = await source.get_candidates() + + assert len(candidates) == 2 + assert candidates[0].symbol == "AAPL" + assert candidates[0].source == "sentiment" + + +def test_parse_llm_selections_valid(): + llm_response = """ + [ + {"symbol": "NVDA", "side": "BUY", "conviction": 0.85, "reason": "AI demand", "key_news": ["NVDA beats earnings"]}, + {"symbol": "XOM", "side": "BUY", "conviction": 0.72, "reason": "Oil surge", "key_news": ["Oil prices up"]} + ] + """ + selections = _parse_llm_selections(llm_response) + assert len(selections) == 2 + assert selections[0].symbol == "NVDA" + assert selections[0].conviction == 0.85 + + +def test_parse_llm_selections_invalid(): + selections = _parse_llm_selections("not json") + assert selections == [] + + +def test_parse_llm_selections_with_markdown(): + llm_response = """ + Here are my picks: + ```json + [ + {"symbol": "TSLA", "side": "BUY", "conviction": 0.7, "reason": "Momentum", "key_news": ["Tesla rally"]} + ] + ``` + """ + selections = _parse_llm_selections(llm_response) + assert len(selections) == 1 + assert selections[0].symbol == "TSLA" + + +def test_extract_json_array_from_markdown(): + text = '```json\n[{"symbol": "AAPL", "score": 0.9}]\n```' + result = _extract_json_array(text) + assert result == [{"symbol": "AAPL", "score": 0.9}] + + +def test_extract_json_array_bare(): + text = '[{"symbol": "TSLA"}]' + result = _extract_json_array(text) + assert result == [{"symbol": "TSLA"}] + + +def test_extract_json_array_invalid(): + assert _extract_json_array("not json") is None + + +def test_extract_json_array_filters_non_dicts(): + text = '[{"symbol": "AAPL"}, "bad", 42]' + result = _extract_json_array(text) + assert result == [{"symbol": "AAPL"}] + + +async def test_selector_close(): + selector = StockSelector( + db=MagicMock(), broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test" + ) + # No session yet - close should be safe + await selector.close() + assert selector._http_session is None + + +async def test_selector_blocks_on_risk_off(): + mock_db = MagicMock() + mock_db.get_latest_market_sentiment = AsyncMock( + return_value={ + "fear_greed": 15, + "fear_greed_label": "Extreme Fear", + "vix": 35.0, + "fed_stance": "neutral", + "market_regime": "risk_off", + "updated_at": datetime.now(UTC), + } + ) + + selector = StockSelector( + db=mock_db, broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test" + ) + result = await selector.select() + assert result == [] diff --git a/services/strategy-engine/tests/test_strategy_validation.py b/services/strategy-engine/tests/test_strategy_validation.py index debab1f..0d9607a 100644 --- a/services/strategy-engine/tests/test_strategy_validation.py +++ b/services/strategy-engine/tests/test_strategy_validation.py @@ -1,13 +1,11 @@ import pytest - -from strategies.rsi_strategy import RsiStrategy -from strategies.macd_strategy import MacdStrategy from strategies.bollinger_strategy import BollingerStrategy from strategies.ema_crossover_strategy import EmaCrossoverStrategy from strategies.grid_strategy import GridStrategy -from strategies.vwap_strategy import VwapStrategy +from strategies.macd_strategy import MacdStrategy +from strategies.rsi_strategy import RsiStrategy from strategies.volume_profile_strategy import VolumeProfileStrategy - +from strategies.vwap_strategy import VwapStrategy # ── RSI ────────────────────────────────────────────────────────────────── diff --git a/services/strategy-engine/tests/test_volume_profile_strategy.py b/services/strategy-engine/tests/test_volume_profile_strategy.py index 65ee2e8..f47898c 100644 --- a/services/strategy-engine/tests/test_volume_profile_strategy.py +++ b/services/strategy-engine/tests/test_volume_profile_strategy.py @@ -1,18 +1,18 @@ """Tests for the Volume Profile strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.volume_profile_strategy import VolumeProfileStrategy from shared.models import Candle, OrderSide -from strategies.volume_profile_strategy import VolumeProfileStrategy def make_candle(close: float, volume: float = 1.0) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), @@ -134,13 +134,10 @@ def test_volume_profile_hvn_detection(): # Create a profile with very high volume at price ~100 and low volume elsewhere # Prices range from 90 to 110, heavy volume concentrated at 100 - candles_data = [] # Low volume at extremes - for p in [90, 91, 92, 109, 110]: - candles_data.append((p, 1.0)) + candles_data = [(p, 1.0) for p in [90, 91, 92, 109, 110]] # Very high volume around 100 - for _ in range(15): - candles_data.append((100, 100.0)) + candles_data.extend((100, 100.0) for _ in range(15)) for price, vol in candles_data: strategy.on_candle(make_candle(price, vol)) @@ -148,7 +145,7 @@ def test_volume_profile_hvn_detection(): # Access the internal method to verify HVN detection result = strategy._compute_value_area() assert result is not None - poc, va_low, va_high, hvn_levels, lvn_levels = result + _poc, _va_low, _va_high, hvn_levels, _lvn_levels = result # The bin containing price ~100 should have very high volume -> HVN assert len(hvn_levels) > 0 diff --git a/services/strategy-engine/tests/test_vwap_strategy.py b/services/strategy-engine/tests/test_vwap_strategy.py index 2c34b01..078d0cf 100644 --- a/services/strategy-engine/tests/test_vwap_strategy.py +++ b/services/strategy-engine/tests/test_vwap_strategy.py @@ -1,11 +1,11 @@ """Tests for the VWAP strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.vwap_strategy import VwapStrategy from shared.models import Candle, OrderSide -from strategies.vwap_strategy import VwapStrategy def make_candle( @@ -20,7 +20,7 @@ def make_candle( if low is None: low = close if open_time is None: - open_time = datetime(2024, 1, 1, tzinfo=timezone.utc) + open_time = datetime(2024, 1, 1, tzinfo=UTC) return Candle( symbol="AAPL", timeframe="1m", @@ -111,11 +111,11 @@ def test_vwap_daily_reset(): """Candles from two different dates cause VWAP to reset.""" strategy = _configured_strategy() - day1 = datetime(2024, 1, 1, tzinfo=timezone.utc) - day2 = datetime(2024, 1, 2, tzinfo=timezone.utc) + day1 = datetime(2024, 1, 1, tzinfo=UTC) + day2 = datetime(2024, 1, 2, tzinfo=UTC) # Feed 35 candles on day 1 to build VWAP state - for i in range(35): + for _i in range(35): strategy.on_candle(make_candle(100.0, high=101.0, low=99.0, open_time=day1)) # Verify state is built up |
