diff options
Diffstat (limited to 'services/strategy-engine/src')
5 files changed, 137 insertions, 121 deletions
diff --git a/services/strategy-engine/src/strategy_engine/config.py b/services/strategy-engine/src/strategy_engine/config.py index 2a9cb43..9fd9c49 100644 --- a/services/strategy-engine/src/strategy_engine/config.py +++ b/services/strategy-engine/src/strategy_engine/config.py @@ -7,9 +7,3 @@ class StrategyConfig(Settings): symbols: list[str] = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"] timeframes: list[str] = ["1m"] strategy_params: dict = {} - selector_candidates_time: str = "15:00" - selector_filter_time: str = "15:15" - selector_final_time: str = "15:30" - selector_max_picks: int = 3 - anthropic_api_key: str = "" - anthropic_model: str = "claude-sonnet-4-20250514" diff --git a/services/strategy-engine/src/strategy_engine/engine.py b/services/strategy-engine/src/strategy_engine/engine.py index d401aee..4b2c468 100644 --- a/services/strategy-engine/src/strategy_engine/engine.py +++ b/services/strategy-engine/src/strategy_engine/engine.py @@ -2,11 +2,11 @@ import logging -from shared.broker import RedisBroker -from shared.events import CandleEvent, SignalEvent, Event - from strategies.base import BaseStrategy +from shared.broker import RedisBroker +from shared.events import CandleEvent, Event, SignalEvent + logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ class StrategyEngine: try: event = Event.from_dict(raw) except Exception as exc: - logger.warning("Failed to parse event: %s – %s", raw, exc) + logger.warning("Failed to parse event: %s - %s", raw, exc) continue if not isinstance(event, CandleEvent): diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py index 5a30766..3d73058 100644 --- a/services/strategy-engine/src/strategy_engine/main.py +++ b/services/strategy-engine/src/strategy_engine/main.py @@ -1,9 +1,11 @@ """Strategy Engine Service entry point.""" import asyncio +import zoneinfo from datetime import datetime from pathlib import Path -import zoneinfo + +import aiohttp from shared.alpaca import AlpacaClient from shared.broker import RedisBroker @@ -13,7 +15,7 @@ from shared.logging import setup_logging from shared.metrics import ServiceMetrics from shared.notifier import TelegramNotifier from shared.sentiment_models import MarketSentiment - +from shared.shutdown import GracefulShutdown from strategy_engine.config import StrategyConfig from strategy_engine.engine import StrategyEngine from strategy_engine.plugin_loader import load_strategies @@ -63,8 +65,12 @@ async def run_stock_selector( log.info("stock_selector_complete", picks=[s.symbol for s in selections]) else: log.info("stock_selector_no_picks") + except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc: + log.warning("stock_selector_network_error", error=str(exc)) + except (ValueError, KeyError, TypeError) as exc: + log.warning("stock_selector_data_error", error=str(exc)) except Exception as exc: - log.error("stock_selector_error", error=str(exc)) + log.error("stock_selector_error", error=str(exc), exc_info=True) await asyncio.sleep(120) # Sleep past this minute else: await asyncio.sleep(30) @@ -76,18 +82,18 @@ async def run() -> None: metrics = ServiceMetrics("strategy_engine") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id, ) - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() alpaca = AlpacaClient( - api_key=config.alpaca_api_key, - api_secret=config.alpaca_api_secret, + api_key=config.alpaca_api_key.get_secret_value(), + api_secret=config.alpaca_api_secret.get_secret_value(), paper=config.alpaca_paper, ) @@ -97,6 +103,9 @@ async def run() -> None: params = config.strategy_params.get(strategy.name, {}) strategy.configure(params) + shutdown = GracefulShutdown() + shutdown.install_handlers() + log.info("loaded_strategies", count=len(strategies), names=[s.name for s in strategies]) engine = StrategyEngine(broker=broker, strategies=strategies) @@ -117,12 +126,12 @@ async def run() -> None: task = asyncio.create_task(process_symbol(engine, stream, log)) tasks.append(task) - if config.anthropic_api_key: + if config.anthropic_api_key.get_secret_value(): selector = StockSelector( db=db, broker=broker, alpaca=alpaca, - anthropic_api_key=config.anthropic_api_key, + anthropic_api_key=config.anthropic_api_key.get_secret_value(), anthropic_model=config.anthropic_model, max_picks=config.selector_max_picks, ) @@ -131,9 +140,9 @@ async def run() -> None: ) log.info("stock_selector_enabled", time=config.selector_final_time) - await asyncio.gather(*tasks) + await shutdown.wait() except Exception as exc: - log.error("fatal_error", error=str(exc)) + log.error("fatal_error", error=str(exc), exc_info=True) await notifier.send_error(str(exc), "strategy-engine") raise finally: diff --git a/services/strategy-engine/src/strategy_engine/plugin_loader.py b/services/strategy-engine/src/strategy_engine/plugin_loader.py index 62e4160..57680db 100644 --- a/services/strategy-engine/src/strategy_engine/plugin_loader.py +++ b/services/strategy-engine/src/strategy_engine/plugin_loader.py @@ -5,7 +5,6 @@ import sys from pathlib import Path import yaml - from strategies.base import BaseStrategy diff --git a/services/strategy-engine/src/strategy_engine/stock_selector.py b/services/strategy-engine/src/strategy_engine/stock_selector.py index 268d557..8657b93 100644 --- a/services/strategy-engine/src/strategy_engine/stock_selector.py +++ b/services/strategy-engine/src/strategy_engine/stock_selector.py @@ -1,9 +1,10 @@ """3-stage stock selector engine: sentiment → technical → LLM.""" +import asyncio import json import logging import re -from datetime import datetime, timezone +from datetime import UTC, datetime import aiohttp @@ -18,18 +19,12 @@ logger = logging.getLogger(__name__) ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages" -def _parse_llm_selections(text: str) -> list[SelectedStock]: - """Parse LLM response into SelectedStock list. - - Handles both bare JSON arrays and markdown code blocks. - Returns empty list on any parse error. - """ - # Try to extract JSON from markdown code block first +def _extract_json_array(text: str) -> list[dict] | None: + """Extract a JSON array from text that may contain markdown code blocks.""" code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) if code_block: raw = code_block.group(1) else: - # Try to find a bare JSON array array_match = re.search(r"\[.*\]", text, re.DOTALL) if array_match: raw = array_match.group(0) @@ -38,27 +33,38 @@ def _parse_llm_selections(text: str) -> list[SelectedStock]: try: data = json.loads(raw) - if not isinstance(data, list): - return [] - selections = [] - for item in data: - if not isinstance(item, dict): - continue - try: - selection = SelectedStock( - symbol=item["symbol"], - side=OrderSide(item["side"]), - conviction=float(item["conviction"]), - reason=item.get("reason", ""), - key_news=item.get("key_news", []), - ) - selections.append(selection) - except (KeyError, ValueError) as e: - logger.warning("Skipping invalid selection item: %s", e) - return selections + if isinstance(data, list): + return [item for item in data if isinstance(item, dict)] + return None except (json.JSONDecodeError, TypeError): + return None + + +def _parse_llm_selections(text: str) -> list[SelectedStock]: + """Parse LLM response into SelectedStock list. + + Handles both bare JSON arrays and markdown code blocks. + Returns empty list on any parse error. + """ + items = _extract_json_array(text) + if items is None: return [] + selections = [] + for item in items: + try: + selection = SelectedStock( + symbol=item["symbol"], + side=OrderSide(item["side"]), + conviction=float(item["conviction"]), + reason=item.get("reason", ""), + key_news=item.get("key_news", []), + ) + selections.append(selection) + except (KeyError, ValueError) as e: + logger.warning("Skipping invalid selection item: %s", e) + return selections + class SentimentCandidateSource: """Generates candidates from DB sentiment scores.""" @@ -92,7 +98,7 @@ class LLMCandidateSource: self._api_key = api_key self._model = model - async def get_candidates(self) -> list[Candidate]: + async def get_candidates(self, session: aiohttp.ClientSession | None = None) -> list[Candidate]: news_items = await self._db.get_recent_news(hours=24) if not news_items: return [] @@ -110,26 +116,29 @@ class LLMCandidateSource: "Headlines:\n" + "\n".join(headlines) ) + own_session = session is None + if own_session: + session = aiohttp.ClientSession() + try: - async with aiohttp.ClientSession() as session: - async with session.post( - ANTHROPIC_API_URL, - headers={ - "x-api-key": self._api_key, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - json={ - "model": self._model, - "max_tokens": 1024, - "messages": [{"role": "user", "content": prompt}], - }, - ) as resp: - if resp.status != 200: - body = await resp.text() - logger.error("LLM candidate source error %d: %s", resp.status, body) - return [] - data = await resp.json() + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM candidate source error %d: %s", resp.status, body) + return [] + data = await resp.json() content = data.get("content", []) text = "" @@ -141,40 +150,32 @@ class LLMCandidateSource: except Exception as e: logger.error("LLMCandidateSource error: %s", e) return [] + finally: + if own_session: + await session.close() def _parse_candidates(self, text: str) -> list[Candidate]: - code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) - if code_block: - raw = code_block.group(1) - else: - array_match = re.search(r"\[.*\]", text, re.DOTALL) - raw = array_match.group(0) if array_match else text.strip() + items = _extract_json_array(text) + if items is None: + return [] - try: - items = json.loads(raw) - if not isinstance(items, list): - return [] - candidates = [] - for item in items: - if not isinstance(item, dict): - continue - try: - direction_str = item.get("direction", "BUY") - direction = OrderSide(direction_str) - except ValueError: - direction = None - candidates.append( - Candidate( - symbol=item["symbol"], - source="llm", - direction=direction, - score=float(item.get("score", 0.5)), - reason=item.get("reason", ""), - ) + candidates = [] + for item in items: + try: + direction_str = item.get("direction", "BUY") + direction = OrderSide(direction_str) + except ValueError: + direction = None + candidates.append( + Candidate( + symbol=item["symbol"], + source="llm", + direction=direction, + score=float(item.get("score", 0.5)), + reason=item.get("reason", ""), ) - return candidates - except (json.JSONDecodeError, TypeError, KeyError): - return [] + ) + return candidates def _compute_rsi(closes: list[float], period: int = 14) -> float: @@ -217,6 +218,18 @@ class StockSelector: self._api_key = anthropic_api_key self._model = anthropic_model self._max_picks = max_picks + self._http_session: aiohttp.ClientSession | None = None + self._session_lock = asyncio.Lock() + + async def _ensure_session(self) -> aiohttp.ClientSession: + async with self._session_lock: + if self._http_session is None or self._http_session.closed: + self._http_session = aiohttp.ClientSession() + return self._http_session + + async def close(self) -> None: + if self._http_session and not self._http_session.closed: + await self._http_session.close() async def select(self) -> list[SelectedStock]: """Run the full 3-stage pipeline and return selected stocks.""" @@ -235,8 +248,9 @@ class StockSelector: sentiment_source = SentimentCandidateSource(self._db) llm_source = LLMCandidateSource(self._db, self._api_key, self._model) + session = await self._ensure_session() sentiment_candidates = await sentiment_source.get_candidates() - llm_candidates = await llm_source.get_candidates() + llm_candidates = await llm_source.get_candidates(session=session) candidates = self._merge_candidates(sentiment_candidates, llm_candidates) if not candidates: @@ -253,7 +267,7 @@ class StockSelector: selections = await self._llm_final_select(filtered, market_sentiment) # Persist and publish - today = datetime.now(timezone.utc).date() + today = datetime.now(UTC).date() sentiment_snapshot = { "fear_greed": market_sentiment.fear_greed, "market_regime": market_sentiment.market_regime, @@ -372,25 +386,25 @@ class StockSelector: ) try: - async with aiohttp.ClientSession() as session: - async with session.post( - ANTHROPIC_API_URL, - headers={ - "x-api-key": self._api_key, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - json={ - "model": self._model, - "max_tokens": 1024, - "messages": [{"role": "user", "content": prompt}], - }, - ) as resp: - if resp.status != 200: - body = await resp.text() - logger.error("LLM final select error %d: %s", resp.status, body) - return [] - data = await resp.json() + session = await self._ensure_session() + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM final select error %d: %s", resp.status, body) + return [] + data = await resp.json() content = data.get("content", []) text = "" |
