From 8da5fb843856bb6585c6753f44d422beaa4a8204 Mon Sep 17 00:00:00 2001 From: TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> Date: Thu, 2 Apr 2026 15:42:23 +0900 Subject: fix: deduplicate LLM JSON parsing and reuse aiohttp sessions in stock selector Extract _extract_json_array() to eliminate duplicate JSON parsing logic between _parse_llm_selections() and LLMCandidateSource._parse_candidates(). Add session reuse in StockSelector via _ensure_session()/close() methods instead of creating new aiohttp.ClientSession per HTTP call. Pass shared session to LLMCandidateSource.get_candidates(). --- .../src/strategy_engine/stock_selector.py | 203 +++++++++++---------- 1 file changed, 107 insertions(+), 96 deletions(-) (limited to 'services/strategy-engine/src/strategy_engine') diff --git a/services/strategy-engine/src/strategy_engine/stock_selector.py b/services/strategy-engine/src/strategy_engine/stock_selector.py index 268d557..cbd9810 100644 --- a/services/strategy-engine/src/strategy_engine/stock_selector.py +++ b/services/strategy-engine/src/strategy_engine/stock_selector.py @@ -18,18 +18,12 @@ logger = logging.getLogger(__name__) ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages" -def _parse_llm_selections(text: str) -> list[SelectedStock]: - """Parse LLM response into SelectedStock list. - - Handles both bare JSON arrays and markdown code blocks. - Returns empty list on any parse error. - """ - # Try to extract JSON from markdown code block first +def _extract_json_array(text: str) -> list[dict] | None: + """Extract a JSON array from text that may contain markdown code blocks.""" code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) if code_block: raw = code_block.group(1) else: - # Try to find a bare JSON array array_match = re.search(r"\[.*\]", text, re.DOTALL) if array_match: raw = array_match.group(0) @@ -38,27 +32,38 @@ def _parse_llm_selections(text: str) -> list[SelectedStock]: try: data = json.loads(raw) - if not isinstance(data, list): - return [] - selections = [] - for item in data: - if not isinstance(item, dict): - continue - try: - selection = SelectedStock( - symbol=item["symbol"], - side=OrderSide(item["side"]), - conviction=float(item["conviction"]), - reason=item.get("reason", ""), - key_news=item.get("key_news", []), - ) - selections.append(selection) - except (KeyError, ValueError) as e: - logger.warning("Skipping invalid selection item: %s", e) - return selections + if isinstance(data, list): + return [item for item in data if isinstance(item, dict)] + return None except (json.JSONDecodeError, TypeError): + return None + + +def _parse_llm_selections(text: str) -> list[SelectedStock]: + """Parse LLM response into SelectedStock list. + + Handles both bare JSON arrays and markdown code blocks. + Returns empty list on any parse error. + """ + items = _extract_json_array(text) + if items is None: return [] + selections = [] + for item in items: + try: + selection = SelectedStock( + symbol=item["symbol"], + side=OrderSide(item["side"]), + conviction=float(item["conviction"]), + reason=item.get("reason", ""), + key_news=item.get("key_news", []), + ) + selections.append(selection) + except (KeyError, ValueError) as e: + logger.warning("Skipping invalid selection item: %s", e) + return selections + class SentimentCandidateSource: """Generates candidates from DB sentiment scores.""" @@ -92,7 +97,7 @@ class LLMCandidateSource: self._api_key = api_key self._model = model - async def get_candidates(self) -> list[Candidate]: + async def get_candidates(self, session: aiohttp.ClientSession | None = None) -> list[Candidate]: news_items = await self._db.get_recent_news(hours=24) if not news_items: return [] @@ -110,26 +115,29 @@ class LLMCandidateSource: "Headlines:\n" + "\n".join(headlines) ) + own_session = session is None + if own_session: + session = aiohttp.ClientSession() + try: - async with aiohttp.ClientSession() as session: - async with session.post( - ANTHROPIC_API_URL, - headers={ - "x-api-key": self._api_key, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - json={ - "model": self._model, - "max_tokens": 1024, - "messages": [{"role": "user", "content": prompt}], - }, - ) as resp: - if resp.status != 200: - body = await resp.text() - logger.error("LLM candidate source error %d: %s", resp.status, body) - return [] - data = await resp.json() + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM candidate source error %d: %s", resp.status, body) + return [] + data = await resp.json() content = data.get("content", []) text = "" @@ -141,40 +149,32 @@ class LLMCandidateSource: except Exception as e: logger.error("LLMCandidateSource error: %s", e) return [] + finally: + if own_session: + await session.close() def _parse_candidates(self, text: str) -> list[Candidate]: - code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) - if code_block: - raw = code_block.group(1) - else: - array_match = re.search(r"\[.*\]", text, re.DOTALL) - raw = array_match.group(0) if array_match else text.strip() + items = _extract_json_array(text) + if items is None: + return [] - try: - items = json.loads(raw) - if not isinstance(items, list): - return [] - candidates = [] - for item in items: - if not isinstance(item, dict): - continue - try: - direction_str = item.get("direction", "BUY") - direction = OrderSide(direction_str) - except ValueError: - direction = None - candidates.append( - Candidate( - symbol=item["symbol"], - source="llm", - direction=direction, - score=float(item.get("score", 0.5)), - reason=item.get("reason", ""), - ) + candidates = [] + for item in items: + try: + direction_str = item.get("direction", "BUY") + direction = OrderSide(direction_str) + except ValueError: + direction = None + candidates.append( + Candidate( + symbol=item["symbol"], + source="llm", + direction=direction, + score=float(item.get("score", 0.5)), + reason=item.get("reason", ""), ) - return candidates - except (json.JSONDecodeError, TypeError, KeyError): - return [] + ) + return candidates def _compute_rsi(closes: list[float], period: int = 14) -> float: @@ -217,6 +217,16 @@ class StockSelector: self._api_key = anthropic_api_key self._model = anthropic_model self._max_picks = max_picks + self._http_session: aiohttp.ClientSession | None = None + + async def _ensure_session(self) -> aiohttp.ClientSession: + if self._http_session is None or self._http_session.closed: + self._http_session = aiohttp.ClientSession() + return self._http_session + + async def close(self) -> None: + if self._http_session and not self._http_session.closed: + await self._http_session.close() async def select(self) -> list[SelectedStock]: """Run the full 3-stage pipeline and return selected stocks.""" @@ -235,8 +245,9 @@ class StockSelector: sentiment_source = SentimentCandidateSource(self._db) llm_source = LLMCandidateSource(self._db, self._api_key, self._model) + session = await self._ensure_session() sentiment_candidates = await sentiment_source.get_candidates() - llm_candidates = await llm_source.get_candidates() + llm_candidates = await llm_source.get_candidates(session=session) candidates = self._merge_candidates(sentiment_candidates, llm_candidates) if not candidates: @@ -372,25 +383,25 @@ class StockSelector: ) try: - async with aiohttp.ClientSession() as session: - async with session.post( - ANTHROPIC_API_URL, - headers={ - "x-api-key": self._api_key, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - json={ - "model": self._model, - "max_tokens": 1024, - "messages": [{"role": "user", "content": prompt}], - }, - ) as resp: - if resp.status != 200: - body = await resp.text() - logger.error("LLM final select error %d: %s", resp.status, body) - return [] - data = await resp.json() + session = await self._ensure_session() + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM final select error %d: %s", resp.status, body) + return [] + data = await resp.json() content = data.get("content", []) text = "" -- cgit v1.2.3