summaryrefslogtreecommitdiff
path: root/services/strategy-engine/src/strategy_engine/stock_selector.py
diff options
context:
space:
mode:
Diffstat (limited to 'services/strategy-engine/src/strategy_engine/stock_selector.py')
-rw-r--r--services/strategy-engine/src/strategy_engine/stock_selector.py203
1 files changed, 107 insertions, 96 deletions
diff --git a/services/strategy-engine/src/strategy_engine/stock_selector.py b/services/strategy-engine/src/strategy_engine/stock_selector.py
index 268d557..cbd9810 100644
--- a/services/strategy-engine/src/strategy_engine/stock_selector.py
+++ b/services/strategy-engine/src/strategy_engine/stock_selector.py
@@ -18,18 +18,12 @@ logger = logging.getLogger(__name__)
ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages"
-def _parse_llm_selections(text: str) -> list[SelectedStock]:
- """Parse LLM response into SelectedStock list.
-
- Handles both bare JSON arrays and markdown code blocks.
- Returns empty list on any parse error.
- """
- # Try to extract JSON from markdown code block first
+def _extract_json_array(text: str) -> list[dict] | None:
+ """Extract a JSON array from text that may contain markdown code blocks."""
code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL)
if code_block:
raw = code_block.group(1)
else:
- # Try to find a bare JSON array
array_match = re.search(r"\[.*\]", text, re.DOTALL)
if array_match:
raw = array_match.group(0)
@@ -38,27 +32,38 @@ def _parse_llm_selections(text: str) -> list[SelectedStock]:
try:
data = json.loads(raw)
- if not isinstance(data, list):
- return []
- selections = []
- for item in data:
- if not isinstance(item, dict):
- continue
- try:
- selection = SelectedStock(
- symbol=item["symbol"],
- side=OrderSide(item["side"]),
- conviction=float(item["conviction"]),
- reason=item.get("reason", ""),
- key_news=item.get("key_news", []),
- )
- selections.append(selection)
- except (KeyError, ValueError) as e:
- logger.warning("Skipping invalid selection item: %s", e)
- return selections
+ if isinstance(data, list):
+ return [item for item in data if isinstance(item, dict)]
+ return None
except (json.JSONDecodeError, TypeError):
+ return None
+
+
+def _parse_llm_selections(text: str) -> list[SelectedStock]:
+ """Parse LLM response into SelectedStock list.
+
+ Handles both bare JSON arrays and markdown code blocks.
+ Returns empty list on any parse error.
+ """
+ items = _extract_json_array(text)
+ if items is None:
return []
+ selections = []
+ for item in items:
+ try:
+ selection = SelectedStock(
+ symbol=item["symbol"],
+ side=OrderSide(item["side"]),
+ conviction=float(item["conviction"]),
+ reason=item.get("reason", ""),
+ key_news=item.get("key_news", []),
+ )
+ selections.append(selection)
+ except (KeyError, ValueError) as e:
+ logger.warning("Skipping invalid selection item: %s", e)
+ return selections
+
class SentimentCandidateSource:
"""Generates candidates from DB sentiment scores."""
@@ -92,7 +97,7 @@ class LLMCandidateSource:
self._api_key = api_key
self._model = model
- async def get_candidates(self) -> list[Candidate]:
+ async def get_candidates(self, session: aiohttp.ClientSession | None = None) -> list[Candidate]:
news_items = await self._db.get_recent_news(hours=24)
if not news_items:
return []
@@ -110,26 +115,29 @@ class LLMCandidateSource:
"Headlines:\n" + "\n".join(headlines)
)
+ own_session = session is None
+ if own_session:
+ session = aiohttp.ClientSession()
+
try:
- async with aiohttp.ClientSession() as session:
- async with session.post(
- ANTHROPIC_API_URL,
- headers={
- "x-api-key": self._api_key,
- "anthropic-version": "2023-06-01",
- "content-type": "application/json",
- },
- json={
- "model": self._model,
- "max_tokens": 1024,
- "messages": [{"role": "user", "content": prompt}],
- },
- ) as resp:
- if resp.status != 200:
- body = await resp.text()
- logger.error("LLM candidate source error %d: %s", resp.status, body)
- return []
- data = await resp.json()
+ async with session.post(
+ ANTHROPIC_API_URL,
+ headers={
+ "x-api-key": self._api_key,
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json",
+ },
+ json={
+ "model": self._model,
+ "max_tokens": 1024,
+ "messages": [{"role": "user", "content": prompt}],
+ },
+ ) as resp:
+ if resp.status != 200:
+ body = await resp.text()
+ logger.error("LLM candidate source error %d: %s", resp.status, body)
+ return []
+ data = await resp.json()
content = data.get("content", [])
text = ""
@@ -141,40 +149,32 @@ class LLMCandidateSource:
except Exception as e:
logger.error("LLMCandidateSource error: %s", e)
return []
+ finally:
+ if own_session:
+ await session.close()
def _parse_candidates(self, text: str) -> list[Candidate]:
- code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL)
- if code_block:
- raw = code_block.group(1)
- else:
- array_match = re.search(r"\[.*\]", text, re.DOTALL)
- raw = array_match.group(0) if array_match else text.strip()
+ items = _extract_json_array(text)
+ if items is None:
+ return []
- try:
- items = json.loads(raw)
- if not isinstance(items, list):
- return []
- candidates = []
- for item in items:
- if not isinstance(item, dict):
- continue
- try:
- direction_str = item.get("direction", "BUY")
- direction = OrderSide(direction_str)
- except ValueError:
- direction = None
- candidates.append(
- Candidate(
- symbol=item["symbol"],
- source="llm",
- direction=direction,
- score=float(item.get("score", 0.5)),
- reason=item.get("reason", ""),
- )
+ candidates = []
+ for item in items:
+ try:
+ direction_str = item.get("direction", "BUY")
+ direction = OrderSide(direction_str)
+ except ValueError:
+ direction = None
+ candidates.append(
+ Candidate(
+ symbol=item["symbol"],
+ source="llm",
+ direction=direction,
+ score=float(item.get("score", 0.5)),
+ reason=item.get("reason", ""),
)
- return candidates
- except (json.JSONDecodeError, TypeError, KeyError):
- return []
+ )
+ return candidates
def _compute_rsi(closes: list[float], period: int = 14) -> float:
@@ -217,6 +217,16 @@ class StockSelector:
self._api_key = anthropic_api_key
self._model = anthropic_model
self._max_picks = max_picks
+ self._http_session: aiohttp.ClientSession | None = None
+
+ async def _ensure_session(self) -> aiohttp.ClientSession:
+ if self._http_session is None or self._http_session.closed:
+ self._http_session = aiohttp.ClientSession()
+ return self._http_session
+
+ async def close(self) -> None:
+ if self._http_session and not self._http_session.closed:
+ await self._http_session.close()
async def select(self) -> list[SelectedStock]:
"""Run the full 3-stage pipeline and return selected stocks."""
@@ -235,8 +245,9 @@ class StockSelector:
sentiment_source = SentimentCandidateSource(self._db)
llm_source = LLMCandidateSource(self._db, self._api_key, self._model)
+ session = await self._ensure_session()
sentiment_candidates = await sentiment_source.get_candidates()
- llm_candidates = await llm_source.get_candidates()
+ llm_candidates = await llm_source.get_candidates(session=session)
candidates = self._merge_candidates(sentiment_candidates, llm_candidates)
if not candidates:
@@ -372,25 +383,25 @@ class StockSelector:
)
try:
- async with aiohttp.ClientSession() as session:
- async with session.post(
- ANTHROPIC_API_URL,
- headers={
- "x-api-key": self._api_key,
- "anthropic-version": "2023-06-01",
- "content-type": "application/json",
- },
- json={
- "model": self._model,
- "max_tokens": 1024,
- "messages": [{"role": "user", "content": prompt}],
- },
- ) as resp:
- if resp.status != 200:
- body = await resp.text()
- logger.error("LLM final select error %d: %s", resp.status, body)
- return []
- data = await resp.json()
+ session = await self._ensure_session()
+ async with session.post(
+ ANTHROPIC_API_URL,
+ headers={
+ "x-api-key": self._api_key,
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json",
+ },
+ json={
+ "model": self._model,
+ "max_tokens": 1024,
+ "messages": [{"role": "user", "content": prompt}],
+ },
+ ) as resp:
+ if resp.status != 200:
+ body = await resp.text()
+ logger.error("LLM final select error %d: %s", resp.status, body)
+ return []
+ data = await resp.json()
content = data.get("content", [])
text = ""