diff options
| author | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-02 15:42:23 +0900 |
|---|---|---|
| committer | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-02 15:42:23 +0900 |
| commit | 8da5fb843856bb6585c6753f44d422beaa4a8204 (patch) | |
| tree | 198c0e6b44ddb0981a48332effaad1276e643886 /services | |
| parent | 0e177eafbed026445e50da6a5992177521fb8212 (diff) | |
fix: deduplicate LLM JSON parsing and reuse aiohttp sessions in stock selector
Extract _extract_json_array() to eliminate duplicate JSON parsing logic
between _parse_llm_selections() and LLMCandidateSource._parse_candidates().
Add session reuse in StockSelector via _ensure_session()/close() methods
instead of creating new aiohttp.ClientSession per HTTP call. Pass shared
session to LLMCandidateSource.get_candidates().
Diffstat (limited to 'services')
| -rw-r--r-- | services/strategy-engine/src/strategy_engine/stock_selector.py | 203 | ||||
| -rw-r--r-- | services/strategy-engine/tests/test_stock_selector.py | 32 |
2 files changed, 139 insertions, 96 deletions
diff --git a/services/strategy-engine/src/strategy_engine/stock_selector.py b/services/strategy-engine/src/strategy_engine/stock_selector.py index 268d557..cbd9810 100644 --- a/services/strategy-engine/src/strategy_engine/stock_selector.py +++ b/services/strategy-engine/src/strategy_engine/stock_selector.py @@ -18,18 +18,12 @@ logger = logging.getLogger(__name__) ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages" -def _parse_llm_selections(text: str) -> list[SelectedStock]: - """Parse LLM response into SelectedStock list. - - Handles both bare JSON arrays and markdown code blocks. - Returns empty list on any parse error. - """ - # Try to extract JSON from markdown code block first +def _extract_json_array(text: str) -> list[dict] | None: + """Extract a JSON array from text that may contain markdown code blocks.""" code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) if code_block: raw = code_block.group(1) else: - # Try to find a bare JSON array array_match = re.search(r"\[.*\]", text, re.DOTALL) if array_match: raw = array_match.group(0) @@ -38,27 +32,38 @@ def _parse_llm_selections(text: str) -> list[SelectedStock]: try: data = json.loads(raw) - if not isinstance(data, list): - return [] - selections = [] - for item in data: - if not isinstance(item, dict): - continue - try: - selection = SelectedStock( - symbol=item["symbol"], - side=OrderSide(item["side"]), - conviction=float(item["conviction"]), - reason=item.get("reason", ""), - key_news=item.get("key_news", []), - ) - selections.append(selection) - except (KeyError, ValueError) as e: - logger.warning("Skipping invalid selection item: %s", e) - return selections + if isinstance(data, list): + return [item for item in data if isinstance(item, dict)] + return None except (json.JSONDecodeError, TypeError): + return None + + +def _parse_llm_selections(text: str) -> list[SelectedStock]: + """Parse LLM response into SelectedStock list. + + Handles both bare JSON arrays and markdown code blocks. + Returns empty list on any parse error. + """ + items = _extract_json_array(text) + if items is None: return [] + selections = [] + for item in items: + try: + selection = SelectedStock( + symbol=item["symbol"], + side=OrderSide(item["side"]), + conviction=float(item["conviction"]), + reason=item.get("reason", ""), + key_news=item.get("key_news", []), + ) + selections.append(selection) + except (KeyError, ValueError) as e: + logger.warning("Skipping invalid selection item: %s", e) + return selections + class SentimentCandidateSource: """Generates candidates from DB sentiment scores.""" @@ -92,7 +97,7 @@ class LLMCandidateSource: self._api_key = api_key self._model = model - async def get_candidates(self) -> list[Candidate]: + async def get_candidates(self, session: aiohttp.ClientSession | None = None) -> list[Candidate]: news_items = await self._db.get_recent_news(hours=24) if not news_items: return [] @@ -110,26 +115,29 @@ class LLMCandidateSource: "Headlines:\n" + "\n".join(headlines) ) + own_session = session is None + if own_session: + session = aiohttp.ClientSession() + try: - async with aiohttp.ClientSession() as session: - async with session.post( - ANTHROPIC_API_URL, - headers={ - "x-api-key": self._api_key, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - json={ - "model": self._model, - "max_tokens": 1024, - "messages": [{"role": "user", "content": prompt}], - }, - ) as resp: - if resp.status != 200: - body = await resp.text() - logger.error("LLM candidate source error %d: %s", resp.status, body) - return [] - data = await resp.json() + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM candidate source error %d: %s", resp.status, body) + return [] + data = await resp.json() content = data.get("content", []) text = "" @@ -141,40 +149,32 @@ class LLMCandidateSource: except Exception as e: logger.error("LLMCandidateSource error: %s", e) return [] + finally: + if own_session: + await session.close() def _parse_candidates(self, text: str) -> list[Candidate]: - code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) - if code_block: - raw = code_block.group(1) - else: - array_match = re.search(r"\[.*\]", text, re.DOTALL) - raw = array_match.group(0) if array_match else text.strip() + items = _extract_json_array(text) + if items is None: + return [] - try: - items = json.loads(raw) - if not isinstance(items, list): - return [] - candidates = [] - for item in items: - if not isinstance(item, dict): - continue - try: - direction_str = item.get("direction", "BUY") - direction = OrderSide(direction_str) - except ValueError: - direction = None - candidates.append( - Candidate( - symbol=item["symbol"], - source="llm", - direction=direction, - score=float(item.get("score", 0.5)), - reason=item.get("reason", ""), - ) + candidates = [] + for item in items: + try: + direction_str = item.get("direction", "BUY") + direction = OrderSide(direction_str) + except ValueError: + direction = None + candidates.append( + Candidate( + symbol=item["symbol"], + source="llm", + direction=direction, + score=float(item.get("score", 0.5)), + reason=item.get("reason", ""), ) - return candidates - except (json.JSONDecodeError, TypeError, KeyError): - return [] + ) + return candidates def _compute_rsi(closes: list[float], period: int = 14) -> float: @@ -217,6 +217,16 @@ class StockSelector: self._api_key = anthropic_api_key self._model = anthropic_model self._max_picks = max_picks + self._http_session: aiohttp.ClientSession | None = None + + async def _ensure_session(self) -> aiohttp.ClientSession: + if self._http_session is None or self._http_session.closed: + self._http_session = aiohttp.ClientSession() + return self._http_session + + async def close(self) -> None: + if self._http_session and not self._http_session.closed: + await self._http_session.close() async def select(self) -> list[SelectedStock]: """Run the full 3-stage pipeline and return selected stocks.""" @@ -235,8 +245,9 @@ class StockSelector: sentiment_source = SentimentCandidateSource(self._db) llm_source = LLMCandidateSource(self._db, self._api_key, self._model) + session = await self._ensure_session() sentiment_candidates = await sentiment_source.get_candidates() - llm_candidates = await llm_source.get_candidates() + llm_candidates = await llm_source.get_candidates(session=session) candidates = self._merge_candidates(sentiment_candidates, llm_candidates) if not candidates: @@ -372,25 +383,25 @@ class StockSelector: ) try: - async with aiohttp.ClientSession() as session: - async with session.post( - ANTHROPIC_API_URL, - headers={ - "x-api-key": self._api_key, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - json={ - "model": self._model, - "max_tokens": 1024, - "messages": [{"role": "user", "content": prompt}], - }, - ) as resp: - if resp.status != 200: - body = await resp.text() - logger.error("LLM final select error %d: %s", resp.status, body) - return [] - data = await resp.json() + session = await self._ensure_session() + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM final select error %d: %s", resp.status, body) + return [] + data = await resp.json() content = data.get("content", []) text = "" diff --git a/services/strategy-engine/tests/test_stock_selector.py b/services/strategy-engine/tests/test_stock_selector.py index ff9d09c..fa15f66 100644 --- a/services/strategy-engine/tests/test_stock_selector.py +++ b/services/strategy-engine/tests/test_stock_selector.py @@ -7,6 +7,7 @@ from datetime import datetime, timezone from strategy_engine.stock_selector import ( SentimentCandidateSource, StockSelector, + _extract_json_array, _parse_llm_selections, ) @@ -60,6 +61,37 @@ def test_parse_llm_selections_with_markdown(): assert selections[0].symbol == "TSLA" +def test_extract_json_array_from_markdown(): + text = '```json\n[{"symbol": "AAPL", "score": 0.9}]\n```' + result = _extract_json_array(text) + assert result == [{"symbol": "AAPL", "score": 0.9}] + + +def test_extract_json_array_bare(): + text = '[{"symbol": "TSLA"}]' + result = _extract_json_array(text) + assert result == [{"symbol": "TSLA"}] + + +def test_extract_json_array_invalid(): + assert _extract_json_array("not json") is None + + +def test_extract_json_array_filters_non_dicts(): + text = '[{"symbol": "AAPL"}, "bad", 42]' + result = _extract_json_array(text) + assert result == [{"symbol": "AAPL"}] + + +async def test_selector_close(): + selector = StockSelector( + db=MagicMock(), broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test" + ) + # No session yet - close should be safe + await selector.close() + assert selector._http_session is None + + async def test_selector_blocks_on_risk_off(): mock_db = MagicMock() mock_db.get_latest_market_sentiment = AsyncMock( |
