diff options
Diffstat (limited to 'services')
98 files changed, 2406 insertions, 297 deletions
diff --git a/services/api/Dockerfile b/services/api/Dockerfile index b942075..93d2b75 100644 --- a/services/api/Dockerfile +++ b/services/api/Dockerfile @@ -1,11 +1,18 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/api/ services/api/ RUN pip install --no-cache-dir ./services/api -COPY services/strategy-engine/strategies/ /app/strategies/ COPY services/strategy-engine/ services/strategy-engine/ RUN pip install --no-cache-dir ./services/strategy-engine -ENV PYTHONPATH=/app -CMD ["uvicorn", "trading_api.main:app", "--host", "0.0.0.0", "--port", "8000"] + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin +COPY services/strategy-engine/strategies/ /app/strategies/ +ENV PYTHONPATH=/app STRATEGIES_DIR=/app/strategies +USER appuser +CMD ["uvicorn", "trading_api.main:app", "--host", "0.0.0.0", "--port", "8000", "--timeout-graceful-shutdown", "30"] diff --git a/services/api/pyproject.toml b/services/api/pyproject.toml index fd2598d..95099d2 100644 --- a/services/api/pyproject.toml +++ b/services/api/pyproject.toml @@ -3,11 +3,7 @@ name = "trading-api" version = "0.1.0" description = "REST API for the trading platform" requires-python = ">=3.12" -dependencies = [ - "fastapi>=0.110", - "uvicorn>=0.27", - "trading-shared", -] +dependencies = ["fastapi>=0.110,<1", "uvicorn>=0.27,<1", "slowapi>=0.1.9,<1", "trading-shared"] [project.optional-dependencies] dev = ["pytest>=8.0", "pytest-asyncio>=0.23", "httpx>=0.27"] diff --git a/services/api/src/trading_api/dependencies/__init__.py b/services/api/src/trading_api/dependencies/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/api/src/trading_api/dependencies/__init__.py diff --git 
a/services/api/src/trading_api/dependencies/auth.py b/services/api/src/trading_api/dependencies/auth.py new file mode 100644 index 0000000..a5e76c1 --- /dev/null +++ b/services/api/src/trading_api/dependencies/auth.py @@ -0,0 +1,29 @@ +"""Bearer token authentication dependency.""" + +import logging + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from shared.config import Settings + +logger = logging.getLogger(__name__) + +_security = HTTPBearer(auto_error=False) +_settings = Settings() + + +async def verify_token( + credentials: HTTPAuthorizationCredentials | None = Depends(_security), +) -> None: + """Verify Bearer token. Skip auth if API_AUTH_TOKEN is not configured.""" + token = _settings.api_auth_token.get_secret_value() + if not token: + return # Auth disabled in dev mode + + if credentials is None or credentials.credentials != token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) diff --git a/services/api/src/trading_api/main.py b/services/api/src/trading_api/main.py index 39f7b43..05c6d2f 100644 --- a/services/api/src/trading_api/main.py +++ b/services/api/src/trading_api/main.py @@ -1,33 +1,71 @@ """Trading Platform REST API.""" +import logging from contextlib import asynccontextmanager -from fastapi import FastAPI +from fastapi import Depends, FastAPI +from fastapi.middleware.cors import CORSMiddleware +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded +from slowapi.util import get_remote_address from shared.config import Settings from shared.db import Database +from trading_api.dependencies.auth import verify_token +from trading_api.routers import orders, portfolio, strategies -from trading_api.routers import portfolio, orders, strategies +logger = logging.getLogger(__name__) @asynccontextmanager async def 
lifespan(app: FastAPI): settings = Settings() - app.state.db = Database(settings.database_url) + if not settings.api_auth_token.get_secret_value(): + logger.warning("API_AUTH_TOKEN not set — authentication is disabled") + app.state.db = Database(settings.database_url.get_secret_value()) await app.state.db.connect() yield await app.state.db.close() +cfg = Settings() + +limiter = Limiter(key_func=get_remote_address) + app = FastAPI( title="Trading Platform API", version="0.1.0", lifespan=lifespan, ) -app.include_router(portfolio.router, prefix="/api/v1/portfolio", tags=["portfolio"]) -app.include_router(orders.router, prefix="/api/v1/orders", tags=["orders"]) -app.include_router(strategies.router, prefix="/api/v1/strategies", tags=["strategies"]) +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + +app.add_middleware( + CORSMiddleware, + allow_origins=cfg.cors_origins.split(","), + allow_methods=["GET", "POST"], + allow_headers=["Authorization", "Content-Type"], +) + +app.include_router( + portfolio.router, + prefix="/api/v1/portfolio", + tags=["portfolio"], + dependencies=[Depends(verify_token)], +) +app.include_router( + orders.router, + prefix="/api/v1/orders", + tags=["orders"], + dependencies=[Depends(verify_token)], +) +app.include_router( + strategies.router, + prefix="/api/v1/strategies", + tags=["strategies"], + dependencies=[Depends(verify_token)], +) @app.get("/health") diff --git a/services/api/src/trading_api/routers/orders.py b/services/api/src/trading_api/routers/orders.py index c69dc10..b664e2a 100644 --- a/services/api/src/trading_api/routers/orders.py +++ b/services/api/src/trading_api/routers/orders.py @@ -2,17 +2,23 @@ import logging -from fastapi import APIRouter, HTTPException, Request -from shared.sa_models import OrderRow, SignalRow +from fastapi import APIRouter, HTTPException, Query, Request +from slowapi import Limiter +from slowapi.util import get_remote_address from sqlalchemy import 
select +from sqlalchemy.exc import OperationalError + +from shared.sa_models import OrderRow, SignalRow logger = logging.getLogger(__name__) router = APIRouter() +limiter = Limiter(key_func=get_remote_address) @router.get("/") -async def get_orders(request: Request, limit: int = 50): +@limiter.limit("60/minute") +async def get_orders(request: Request, limit: int = Query(50, ge=1, le=1000)): """Get recent orders.""" try: db = request.app.state.db @@ -35,13 +41,17 @@ async def get_orders(request: Request, limit: int = 50): } for r in rows ] + except OperationalError as exc: + logger.error("Database error fetching orders: %s", exc) + raise HTTPException(status_code=503, detail="Database unavailable") from exc except Exception as exc: - logger.error("Failed to get orders: %s", exc) - raise HTTPException(status_code=500, detail="Failed to retrieve orders") + logger.error("Failed to get orders: %s", exc, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to retrieve orders") from exc @router.get("/signals") -async def get_signals(request: Request, limit: int = 50): +@limiter.limit("60/minute") +async def get_signals(request: Request, limit: int = Query(50, ge=1, le=1000)): """Get recent signals.""" try: db = request.app.state.db @@ -62,6 +72,9 @@ async def get_signals(request: Request, limit: int = 50): } for r in rows ] + except OperationalError as exc: + logger.error("Database error fetching signals: %s", exc) + raise HTTPException(status_code=503, detail="Database unavailable") from exc except Exception as exc: - logger.error("Failed to get signals: %s", exc) - raise HTTPException(status_code=500, detail="Failed to retrieve signals") + logger.error("Failed to get signals: %s", exc, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to retrieve signals") from exc diff --git a/services/api/src/trading_api/routers/portfolio.py b/services/api/src/trading_api/routers/portfolio.py index d76d85d..56bee7c 100644 --- 
a/services/api/src/trading_api/routers/portfolio.py +++ b/services/api/src/trading_api/routers/portfolio.py @@ -2,9 +2,11 @@ import logging -from fastapi import APIRouter, HTTPException, Request -from shared.sa_models import PositionRow +from fastapi import APIRouter, HTTPException, Query, Request from sqlalchemy import select +from sqlalchemy.exc import OperationalError + +from shared.sa_models import PositionRow logger = logging.getLogger(__name__) @@ -29,13 +31,16 @@ async def get_positions(request: Request): } for r in rows ] + except OperationalError as exc: + logger.error("Database error fetching positions: %s", exc) + raise HTTPException(status_code=503, detail="Database unavailable") from exc except Exception as exc: - logger.error("Failed to get positions: %s", exc) - raise HTTPException(status_code=500, detail="Failed to retrieve positions") + logger.error("Failed to get positions: %s", exc, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to retrieve positions") from exc @router.get("/snapshots") -async def get_snapshots(request: Request, days: int = 30): +async def get_snapshots(request: Request, days: int = Query(30, ge=1, le=365)): """Get portfolio snapshots for the last N days.""" try: db = request.app.state.db @@ -49,6 +54,9 @@ async def get_snapshots(request: Request, days: int = 30): } for s in snapshots ] + except OperationalError as exc: + logger.error("Database error fetching snapshots: %s", exc) + raise HTTPException(status_code=503, detail="Database unavailable") from exc except Exception as exc: - logger.error("Failed to get snapshots: %s", exc) - raise HTTPException(status_code=500, detail="Failed to retrieve snapshots") + logger.error("Failed to get snapshots: %s", exc, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to retrieve snapshots") from exc diff --git a/services/api/src/trading_api/routers/strategies.py b/services/api/src/trading_api/routers/strategies.py index 7ddd54e..157094c 100644 --- 
a/services/api/src/trading_api/routers/strategies.py +++ b/services/api/src/trading_api/routers/strategies.py @@ -42,6 +42,9 @@ async def list_strategies(): } for s in strategies ] + except (ImportError, FileNotFoundError) as exc: + logger.error("Strategy loading error: %s", exc) + raise HTTPException(status_code=503, detail="Strategy engine unavailable") from exc except Exception as exc: - logger.error("Failed to list strategies: %s", exc) - raise HTTPException(status_code=500, detail="Failed to list strategies") + logger.error("Failed to list strategies: %s", exc, exc_info=True) + raise HTTPException(status_code=500, detail="Failed to list strategies") from exc diff --git a/services/api/tests/test_api.py b/services/api/tests/test_api.py index 669143b..f3b0a47 100644 --- a/services/api/tests/test_api.py +++ b/services/api/tests/test_api.py @@ -1,6 +1,7 @@ """Tests for the REST API.""" from unittest.mock import AsyncMock, patch + from fastapi.testclient import TestClient diff --git a/services/api/tests/test_orders_router.py b/services/api/tests/test_orders_router.py index 0658619..52252c5 100644 --- a/services/api/tests/test_orders_router.py +++ b/services/api/tests/test_orders_router.py @@ -1,10 +1,10 @@ """Tests for orders API router.""" -import pytest from unittest.mock import AsyncMock, MagicMock -from fastapi.testclient import TestClient -from fastapi import FastAPI +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient from trading_api.routers.orders import router diff --git a/services/api/tests/test_portfolio_router.py b/services/api/tests/test_portfolio_router.py index f2584ea..8cd8ff8 100644 --- a/services/api/tests/test_portfolio_router.py +++ b/services/api/tests/test_portfolio_router.py @@ -1,11 +1,11 @@ """Tests for portfolio API router.""" -import pytest from decimal import Decimal from unittest.mock import AsyncMock, MagicMock -from fastapi.testclient import TestClient -from fastapi import FastAPI +import pytest +from 
fastapi import FastAPI +from fastapi.testclient import TestClient from trading_api.routers.portfolio import router @@ -45,7 +45,7 @@ def test_get_positions_with_data(app, mock_db): app.state.db = db mock_row = MagicMock() - mock_row.symbol = "BTCUSDT" + mock_row.symbol = "AAPL" mock_row.quantity = Decimal("0.1") mock_row.avg_entry_price = Decimal("50000") mock_row.current_price = Decimal("55000") @@ -59,7 +59,7 @@ def test_get_positions_with_data(app, mock_db): assert response.status_code == 200 data = response.json() assert len(data) == 1 - assert data[0]["symbol"] == "BTCUSDT" + assert data[0]["symbol"] == "AAPL" def test_get_snapshots_empty(app, mock_db): diff --git a/services/backtester/Dockerfile b/services/backtester/Dockerfile index 9a4f439..1108e42 100644 --- a/services/backtester/Dockerfile +++ b/services/backtester/Dockerfile @@ -1,10 +1,17 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/backtester/ services/backtester/ RUN pip install --no-cache-dir ./services/backtester + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin COPY services/strategy-engine/strategies/ /app/strategies/ ENV STRATEGIES_DIR=/app/strategies ENV PYTHONPATH=/app +USER appuser CMD ["python", "-m", "backtester.main"] diff --git a/services/backtester/pyproject.toml b/services/backtester/pyproject.toml index 2601d04..034bcf6 100644 --- a/services/backtester/pyproject.toml +++ b/services/backtester/pyproject.toml @@ -3,7 +3,7 @@ name = "backtester" version = "0.1.0" description = "Strategy backtesting engine" requires-python = ">=3.12" -dependencies = ["pandas>=2.0", "numpy>=1.20", "rich>=13.0", "trading-shared"] +dependencies = ["pandas>=2.1,<3", "numpy>=1.26,<3", "rich>=13.0,<14", "trading-shared"] 
[project.optional-dependencies] dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] diff --git a/services/backtester/src/backtester/config.py b/services/backtester/src/backtester/config.py index f7897da..57ee1fb 100644 --- a/services/backtester/src/backtester/config.py +++ b/services/backtester/src/backtester/config.py @@ -5,7 +5,7 @@ from shared.config import Settings class BacktestConfig(Settings): backtest_initial_balance: float = 10000.0 - symbol: str = "BTCUSDT" + symbol: str = "AAPL" timeframe: str = "1h" strategy_name: str = "rsi_strategy" candle_limit: int = 500 diff --git a/services/backtester/src/backtester/engine.py b/services/backtester/src/backtester/engine.py index b03715d..fcf48f1 100644 --- a/services/backtester/src/backtester/engine.py +++ b/services/backtester/src/backtester/engine.py @@ -6,10 +6,9 @@ from dataclasses import dataclass, field from decimal import Decimal from typing import Protocol -from shared.models import Candle, Signal - from backtester.metrics import DetailedMetrics, TradeRecord, compute_detailed_metrics from backtester.simulator import OrderSimulator, SimulatedTrade +from shared.models import Candle, Signal class StrategyProtocol(Protocol): @@ -101,7 +100,7 @@ class BacktestEngine: final_balance = simulator.balance if candles: last_price = candles[-1].close - for symbol, qty in simulator.positions.items(): + for qty in simulator.positions.values(): if qty > Decimal("0"): final_balance += qty * last_price elif qty < Decimal("0"): diff --git a/services/backtester/src/backtester/main.py b/services/backtester/src/backtester/main.py index a4cea76..dbde00b 100644 --- a/services/backtester/src/backtester/main.py +++ b/services/backtester/src/backtester/main.py @@ -17,11 +17,11 @@ _STRATEGIES_DIR = Path( if _STRATEGIES_DIR.parent not in [Path(p) for p in sys.path]: sys.path.insert(0, str(_STRATEGIES_DIR.parent)) -from shared.db import Database # noqa: E402 -from shared.models import Candle # noqa: E402 from backtester.config import 
BacktestConfig # noqa: E402 from backtester.engine import BacktestEngine # noqa: E402 from backtester.reporter import format_report # noqa: E402 +from shared.db import Database # noqa: E402 +from shared.models import Candle # noqa: E402 async def run_backtest() -> str: @@ -45,7 +45,7 @@ async def run_backtest() -> str: except Exception as exc: raise RuntimeError(f"Failed to load strategy '{config.strategy_name}': {exc}") from exc - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() try: rows = await db.get_candles(config.symbol, config.timeframe, config.candle_limit) diff --git a/services/backtester/src/backtester/metrics.py b/services/backtester/src/backtester/metrics.py index 239cb6f..c7b032b 100644 --- a/services/backtester/src/backtester/metrics.py +++ b/services/backtester/src/backtester/metrics.py @@ -266,7 +266,7 @@ def compute_detailed_metrics( largest_win=largest_win, largest_loss=largest_loss, avg_holding_period=avg_holding, - trade_pairs=[p for p in pairs], + trade_pairs=list(pairs), risk_free_rate=risk_free_rate, recovery_factor=recovery_factor, max_consecutive_losses=max_consec_losses, diff --git a/services/backtester/src/backtester/simulator.py b/services/backtester/src/backtester/simulator.py index 64c88dd..6bce18b 100644 --- a/services/backtester/src/backtester/simulator.py +++ b/services/backtester/src/backtester/simulator.py @@ -1,9 +1,8 @@ """Simulated order executor for backtesting.""" from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal -from typing import Optional from shared.models import OrderSide, Signal @@ -16,7 +15,7 @@ class SimulatedTrade: quantity: Decimal balance_after: Decimal fee: Decimal = Decimal("0") - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + timestamp: datetime = field(default_factory=lambda: datetime.now(UTC)) @dataclass @@ -27,8 +26,8 @@ 
class OpenPosition: side: OrderSide # BUY = long, SELL = short entry_price: Decimal quantity: Decimal - stop_loss: Optional[Decimal] = None - take_profit: Optional[Decimal] = None + stop_loss: Decimal | None = None + take_profit: Decimal | None = None class OrderSimulator: @@ -70,7 +69,7 @@ class OrderSimulator: remaining: list[OpenPosition] = [] for pos in self.open_positions: triggered = False - exit_price: Optional[Decimal] = None + exit_price: Decimal | None = None if pos.side == OrderSide.BUY: # Long position if pos.stop_loss is not None and candle_low <= pos.stop_loss: @@ -125,12 +124,12 @@ class OrderSimulator: def execute( self, signal: Signal, - timestamp: Optional[datetime] = None, - stop_loss: Optional[Decimal] = None, - take_profit: Optional[Decimal] = None, + timestamp: datetime | None = None, + stop_loss: Decimal | None = None, + take_profit: Decimal | None = None, ) -> bool: """Execute a signal with slippage and fees. Returns True if accepted.""" - ts = timestamp or datetime.now(timezone.utc) + ts = timestamp or datetime.now(UTC) exec_price = self._apply_slippage(signal.price, signal.side) fee = self._calculate_fee(exec_price, signal.quantity) diff --git a/services/backtester/src/backtester/walk_forward.py b/services/backtester/src/backtester/walk_forward.py index c7b7fd8..720ad5e 100644 --- a/services/backtester/src/backtester/walk_forward.py +++ b/services/backtester/src/backtester/walk_forward.py @@ -1,11 +1,11 @@ """Walk-forward analysis for strategy parameter optimization.""" +from collections.abc import Callable from dataclasses import dataclass, field from decimal import Decimal -from typing import Callable -from shared.models import Candle from backtester.engine import BacktestEngine, BacktestResult, StrategyProtocol +from shared.models import Candle @dataclass diff --git a/services/backtester/tests/test_engine.py b/services/backtester/tests/test_engine.py index 4794e63..f789831 100644 --- a/services/backtester/tests/test_engine.py +++ 
b/services/backtester/tests/test_engine.py @@ -1,20 +1,19 @@ """Tests for the BacktestEngine.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal from unittest.mock import MagicMock - -from shared.models import Candle, Signal, OrderSide - from backtester.engine import BacktestEngine +from shared.models import Candle, OrderSide, Signal + def make_candle(symbol: str, price: float, timeframe: str = "1h") -> Candle: return Candle( symbol=symbol, timeframe=timeframe, - open_time=datetime.now(timezone.utc), + open_time=datetime.now(UTC), open=Decimal(str(price)), high=Decimal(str(price * 1.01)), low=Decimal(str(price * 0.99)), diff --git a/services/backtester/tests/test_metrics.py b/services/backtester/tests/test_metrics.py index 55f5b6c..13e545e 100644 --- a/services/backtester/tests/test_metrics.py +++ b/services/backtester/tests/test_metrics.py @@ -1,17 +1,16 @@ """Tests for detailed backtest metrics.""" import math -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from decimal import Decimal import pytest - from backtester.metrics import TradeRecord, compute_detailed_metrics def _make_trade(side: str, price: str, minutes_offset: int = 0) -> TradeRecord: return TradeRecord( - time=datetime(2025, 1, 1, tzinfo=timezone.utc) + timedelta(minutes=minutes_offset), + time=datetime(2025, 1, 1, tzinfo=UTC) + timedelta(minutes=minutes_offset), symbol="AAPL", side=side, price=Decimal(price), @@ -124,7 +123,7 @@ def test_consecutive_losses(): def test_risk_free_rate_affects_sharpe(): """Higher risk-free rate should lower Sharpe ratio.""" - base = datetime(2025, 1, 1, tzinfo=timezone.utc) + base = datetime(2025, 1, 1, tzinfo=UTC) trades = [ TradeRecord( time=base, symbol="AAPL", side="BUY", price=Decimal("100"), quantity=Decimal("1") @@ -184,7 +183,7 @@ def test_daily_returns_populated(): def test_fee_subtracted_from_pnl(): """Fees should be subtracted from trade PnL.""" - base 
= datetime(2025, 1, 1, tzinfo=timezone.utc) + base = datetime(2025, 1, 1, tzinfo=UTC) trades_with_fees = [ TradeRecord( time=base, diff --git a/services/backtester/tests/test_simulator.py b/services/backtester/tests/test_simulator.py index 62e2cdb..f85594f 100644 --- a/services/backtester/tests/test_simulator.py +++ b/services/backtester/tests/test_simulator.py @@ -1,11 +1,12 @@ """Tests for the OrderSimulator.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal -from shared.models import OrderSide, Signal from backtester.simulator import OrderSimulator +from shared.models import OrderSide, Signal + def make_signal( symbol: str, @@ -135,7 +136,7 @@ def test_stop_loss_triggers(): signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(signal, stop_loss=Decimal("48000")) - ts = datetime(2025, 1, 1, tzinfo=timezone.utc) + ts = datetime(2025, 1, 1, tzinfo=UTC) closed = sim.check_stops( candle_high=Decimal("50500"), candle_low=Decimal("47500"), # below stop_loss @@ -153,7 +154,7 @@ def test_take_profit_triggers(): signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(signal, take_profit=Decimal("55000")) - ts = datetime(2025, 1, 1, tzinfo=timezone.utc) + ts = datetime(2025, 1, 1, tzinfo=UTC) closed = sim.check_stops( candle_high=Decimal("56000"), # above take_profit candle_low=Decimal("50000"), @@ -171,7 +172,7 @@ def test_stop_not_triggered_within_range(): signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(signal, stop_loss=Decimal("48000"), take_profit=Decimal("55000")) - ts = datetime(2025, 1, 1, tzinfo=timezone.utc) + ts = datetime(2025, 1, 1, tzinfo=UTC) closed = sim.check_stops( candle_high=Decimal("52000"), candle_low=Decimal("49000"), @@ -212,7 +213,7 @@ def test_short_stop_loss(): signal = make_signal("AAPL", OrderSide.SELL, "50000", "0.1") sim.execute(signal, stop_loss=Decimal("52000")) - ts = datetime(2025, 1, 1, tzinfo=timezone.utc) + ts = 
datetime(2025, 1, 1, tzinfo=UTC) closed = sim.check_stops( candle_high=Decimal("53000"), # above stop_loss candle_low=Decimal("49000"), diff --git a/services/backtester/tests/test_walk_forward.py b/services/backtester/tests/test_walk_forward.py index 5ab2e7b..b1aa12c 100644 --- a/services/backtester/tests/test_walk_forward.py +++ b/services/backtester/tests/test_walk_forward.py @@ -1,18 +1,18 @@ """Tests for walk-forward analysis.""" import sys -from pathlib import Path +from datetime import UTC, datetime, timedelta from decimal import Decimal -from datetime import datetime, timedelta, timezone - +from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "strategy-engine")) -from shared.models import Candle from backtester.walk_forward import WalkForwardEngine, WalkForwardResult from strategies.rsi_strategy import RsiStrategy +from shared.models import Candle + def _generate_candles(n=100, base_price=100.0): candles = [] @@ -21,9 +21,9 @@ def _generate_candles(n=100, base_price=100.0): price = base_price + (i % 20) - 10 candles.append( Candle( - symbol="BTCUSDT", + symbol="AAPL", timeframe="1h", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc) + timedelta(hours=i), + open_time=datetime(2025, 1, 1, tzinfo=UTC) + timedelta(hours=i), open=Decimal(str(price)), high=Decimal(str(price + 5)), low=Decimal(str(price - 5)), diff --git a/services/data-collector/Dockerfile b/services/data-collector/Dockerfile index 8cb8af4..4d154c5 100644 --- a/services/data-collector/Dockerfile +++ b/services/data-collector/Dockerfile @@ -1,8 +1,15 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/data-collector/ services/data-collector/ RUN pip install --no-cache-dir ./services/data-collector + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder 
/usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin ENV PYTHONPATH=/app +USER appuser CMD ["python", "-m", "data_collector.main"] diff --git a/services/data-collector/src/data_collector/main.py b/services/data-collector/src/data_collector/main.py index b42b34c..2d44848 100644 --- a/services/data-collector/src/data_collector/main.py +++ b/services/data-collector/src/data_collector/main.py @@ -2,6 +2,9 @@ import asyncio +import aiohttp + +from data_collector.config import CollectorConfig from shared.alpaca import AlpacaClient from shared.broker import RedisBroker from shared.db import Database @@ -11,8 +14,7 @@ from shared.logging import setup_logging from shared.metrics import ServiceMetrics from shared.models import Candle from shared.notifier import TelegramNotifier - -from data_collector.config import CollectorConfig +from shared.shutdown import GracefulShutdown # Health check port: base + 0 HEALTH_PORT_OFFSET = 0 @@ -45,8 +47,10 @@ async def fetch_latest_bars( volume=Decimal(str(bar["v"])), ) candles.append(candle) - except Exception as exc: - log.warning("fetch_bar_failed", symbol=symbol, error=str(exc)) + except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc: + log.warning("fetch_bar_network_error", symbol=symbol, error=str(exc)) + except (ValueError, KeyError, TypeError) as exc: + log.warning("fetch_bar_parse_error", symbol=symbol, error=str(exc)) return candles @@ -56,18 +60,18 @@ async def run() -> None: metrics = ServiceMetrics("data_collector") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id, ) - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) alpaca = AlpacaClient( - 
api_key=config.alpaca_api_key, - api_secret=config.alpaca_api_secret, + api_key=config.alpaca_api_key.get_secret_value(), + api_secret=config.alpaca_api_secret.get_secret_value(), paper=config.alpaca_paper, ) @@ -83,14 +87,17 @@ async def run() -> None: symbols = config.symbols timeframe = config.timeframes[0] if config.timeframes else "1Day" + shutdown = GracefulShutdown() + shutdown.install_handlers() + log.info("starting", symbols=symbols, timeframe=timeframe, poll_interval=poll_interval) try: - while True: + while not shutdown.is_shutting_down: # Check if market is open try: is_open = await alpaca.is_market_open() - except Exception: + except (aiohttp.ClientError, ConnectionError, TimeoutError): is_open = False if is_open: @@ -109,7 +116,7 @@ async def run() -> None: await asyncio.sleep(poll_interval) except Exception as exc: - log.error("fatal_error", error=str(exc)) + log.error("fatal_error", error=str(exc), exc_info=True) await notifier.send_error(str(exc), "data-collector") raise finally: diff --git a/services/data-collector/tests/test_storage.py b/services/data-collector/tests/test_storage.py index be85578..51f3aee 100644 --- a/services/data-collector/tests/test_storage.py +++ b/services/data-collector/tests/test_storage.py @@ -1,19 +1,20 @@ """Tests for storage module.""" -import pytest +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock -from shared.models import Candle +import pytest from data_collector.storage import CandleStorage +from shared.models import Candle + -def _make_candle(symbol: str = "BTCUSDT") -> Candle: +def _make_candle(symbol: str = "AAPL") -> Candle: return Candle( symbol=symbol, timeframe="1m", - open_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC), open=Decimal("30000"), high=Decimal("30100"), low=Decimal("29900"), @@ -39,11 +40,11 @@ async def 
test_storage_saves_to_db_and_publishes(): mock_broker.publish.assert_called_once() stream_arg = mock_broker.publish.call_args[0][0] - assert stream_arg == "candles.BTCUSDT" + assert stream_arg == "candles.AAPL" data_arg = mock_broker.publish.call_args[0][1] assert data_arg["type"] == "CANDLE" - assert data_arg["data"]["symbol"] == "BTCUSDT" + assert data_arg["data"]["symbol"] == "AAPL" @pytest.mark.asyncio diff --git a/services/news-collector/Dockerfile b/services/news-collector/Dockerfile new file mode 100644 index 0000000..7accee2 --- /dev/null +++ b/services/news-collector/Dockerfile @@ -0,0 +1,17 @@ +FROM python:3.12-slim AS builder +WORKDIR /app +COPY shared/ shared/ +RUN pip install --no-cache-dir ./shared +COPY services/news-collector/ services/news-collector/ +RUN pip install --no-cache-dir ./services/news-collector +RUN python -c "import nltk; nltk.download('vader_lexicon', download_dir='/usr/local/nltk_data')" + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin +COPY --from=builder /usr/local/nltk_data /usr/local/nltk_data +ENV PYTHONPATH=/app +USER appuser +CMD ["python", "-m", "news_collector.main"] diff --git a/services/news-collector/pyproject.toml b/services/news-collector/pyproject.toml new file mode 100644 index 0000000..6e62b70 --- /dev/null +++ b/services/news-collector/pyproject.toml @@ -0,0 +1,20 @@ +[project] +name = "news-collector" +version = "0.1.0" +description = "News and sentiment data collector service" +requires-python = ">=3.12" +dependencies = ["trading-shared", "feedparser>=6.0,<7", "nltk>=3.8,<4", "aiohttp>=3.9,<4"] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", + "aioresponses>=0.7", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = 
["src/news_collector"] diff --git a/services/news-collector/src/news_collector/__init__.py b/services/news-collector/src/news_collector/__init__.py new file mode 100644 index 0000000..5547af2 --- /dev/null +++ b/services/news-collector/src/news_collector/__init__.py @@ -0,0 +1 @@ +"""News collector service.""" diff --git a/services/news-collector/src/news_collector/collectors/__init__.py b/services/news-collector/src/news_collector/collectors/__init__.py new file mode 100644 index 0000000..5ef36a7 --- /dev/null +++ b/services/news-collector/src/news_collector/collectors/__init__.py @@ -0,0 +1 @@ +"""News collectors.""" diff --git a/services/news-collector/src/news_collector/collectors/base.py b/services/news-collector/src/news_collector/collectors/base.py new file mode 100644 index 0000000..bb43fd6 --- /dev/null +++ b/services/news-collector/src/news_collector/collectors/base.py @@ -0,0 +1,18 @@ +"""Base class for all news collectors.""" + +from abc import ABC, abstractmethod + +from shared.models import NewsItem + + +class BaseCollector(ABC): + name: str = "base" + poll_interval: int = 300 # seconds + + @abstractmethod + async def collect(self) -> list[NewsItem]: + """Collect news items from the source.""" + + @abstractmethod + async def is_available(self) -> bool: + """Check if this data source is accessible.""" diff --git a/services/news-collector/src/news_collector/collectors/fear_greed.py b/services/news-collector/src/news_collector/collectors/fear_greed.py new file mode 100644 index 0000000..42e8f88 --- /dev/null +++ b/services/news-collector/src/news_collector/collectors/fear_greed.py @@ -0,0 +1,62 @@ +"""CNN Fear & Greed Index collector.""" + +import logging +from dataclasses import dataclass + +import aiohttp + +from news_collector.collectors.base import BaseCollector + +logger = logging.getLogger(__name__) + +FEAR_GREED_URL = "https://production.dataviz.cnn.io/index/fearandgreed/graphdata" + + +@dataclass +class FearGreedResult: + fear_greed: int + 
"""CNN Fear & Greed Index collector."""

import logging
from dataclasses import dataclass

import aiohttp

from news_collector.collectors.base import BaseCollector

logger = logging.getLogger(__name__)

FEAR_GREED_URL = "https://production.dataviz.cnn.io/index/fearandgreed/graphdata"


@dataclass
class FearGreedResult:
    # Index value in [0, 100] and its human-readable bucket label.
    fear_greed: int
    fear_greed_label: str


class FearGreedCollector(BaseCollector):
    """Polls CNN's fear/greed endpoint once an hour."""

    name = "fear_greed"
    poll_interval = 3600  # 1 hour

    async def is_available(self) -> bool:
        """The endpoint is public and unauthenticated; always reachable."""
        return True

    async def _fetch_index(self) -> dict | None:
        """Return the raw JSON payload, or None on any HTTP/network failure."""
        request_headers = {"User-Agent": "Mozilla/5.0"}
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    FEAR_GREED_URL,
                    headers=request_headers,
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as resp:
                    return await resp.json() if resp.status == 200 else None
        except Exception:
            # Best-effort fetch: any failure is reported as "no data".
            return None

    def _classify(self, score: int) -> str:
        """Map a 0-100 score onto CNN's five sentiment buckets."""
        buckets = (
            (20, "Extreme Fear"),
            (40, "Fear"),
            (60, "Neutral"),
            (80, "Greed"),
        )
        for upper_bound, label in buckets:
            if score <= upper_bound:
                return label
        return "Extreme Greed"

    async def collect(self) -> FearGreedResult | None:
        """Fetch and parse the index; None when unavailable or malformed."""
        payload = await self._fetch_index()
        if payload is None:
            return None
        try:
            section = payload["fear_and_greed"]
            value = int(section["score"])
            # Prefer CNN's own rating; fall back to our local bucketing.
            return FearGreedResult(
                fear_greed=value,
                fear_greed_label=section.get("rating", self._classify(value)),
            )
        except (KeyError, ValueError, TypeError):
            return None
"""Federal Reserve RSS collector with hawkish/dovish/neutral stance detection."""

import asyncio
import logging
from calendar import timegm
from datetime import UTC, datetime

import feedparser
from nltk.sentiment.vader import SentimentIntensityAnalyzer

from shared.models import NewsCategory, NewsItem

from .base import BaseCollector

logger = logging.getLogger(__name__)

_FED_RSS_URL = "https://www.federalreserve.gov/feeds/press_all.xml"

# Keyword lists driving the majority-vote stance classifier below.
_HAWKISH_KEYWORDS = [
    "rate hike",
    "interest rate increase",
    "tighten",
    "tightening",
    "inflation",
    "hawkish",
    "restrictive",
    "raise rates",
    "hike rates",
]
_DOVISH_KEYWORDS = [
    "rate cut",
    "interest rate decrease",
    "easing",
    "ease",
    "stimulus",
    "dovish",
    "accommodative",
    "lower rates",
    "cut rates",
    "quantitative easing",
]


def _detect_stance(text: str) -> str:
    """Classify text as hawkish/dovish/neutral by keyword majority vote."""
    lower = text.lower()
    hawkish_hits = sum(1 for kw in _HAWKISH_KEYWORDS if kw in lower)
    dovish_hits = sum(1 for kw in _DOVISH_KEYWORDS if kw in lower)
    if hawkish_hits > dovish_hits:
        return "hawkish"
    if dovish_hits > hawkish_hits:
        return "dovish"
    return "neutral"


class FedCollector(BaseCollector):
    """Collects Fed press releases and tags each item with a policy stance."""

    name: str = "fed"
    poll_interval: int = 3600

    def __init__(self) -> None:
        self._vader = SentimentIntensityAnalyzer()

    async def is_available(self) -> bool:
        return True

    async def _fetch_fed_rss(self) -> list[dict]:
        """Parse the Fed RSS feed off the event loop (feedparser blocks)."""
        # get_running_loop() is the supported call from inside a coroutine;
        # get_event_loop() is deprecated in this context since Python 3.10.
        loop = asyncio.get_running_loop()
        try:
            parsed = await loop.run_in_executor(None, feedparser.parse, _FED_RSS_URL)
            return parsed.get("entries", [])
        except Exception as exc:
            logger.error("Fed RSS fetch failed: %s", exc)
            return []

    def _parse_published(self, entry: dict) -> datetime:
        """Convert feedparser's UTC struct_time to an aware datetime.

        Falls back to "now" when the entry carries no usable timestamp.
        """
        published_parsed = entry.get("published_parsed")
        if published_parsed:
            try:
                # published_parsed is in UTC, so timegm (not mktime) is correct.
                ts = timegm(published_parsed)
                return datetime.fromtimestamp(ts, tz=UTC)
            except Exception:
                pass
        return datetime.now(UTC)

    async def collect(self) -> list[NewsItem]:
        """Fetch the feed and build NewsItems with sentiment and stance metadata."""
        try:
            entries = await self._fetch_fed_rss()
        except Exception as exc:
            logger.error("Fed collector error: %s", exc)
            return []

        items: list[NewsItem] = []
        for entry in entries:
            title = entry.get("title", "").strip()
            if not title:
                continue

            summary = entry.get("summary", "") or ""
            combined = f"{title} {summary}"

            sentiment = self._vader.polarity_scores(combined)["compound"]
            stance = _detect_stance(combined)

            items.append(
                NewsItem(
                    source=self.name,
                    headline=title,
                    summary=summary or None,
                    url=entry.get("link") or None,
                    published_at=self._parse_published(entry),
                    symbols=[],
                    sentiment=sentiment,
                    category=NewsCategory.FED,
                    # Stance rides along in raw_data for downstream consumers.
                    raw_data={"stance": stance, **dict(entry)},
                )
            )

        return items
"""Finnhub news collector with VADER sentiment analysis."""

import logging
from datetime import UTC, datetime

import aiohttp
from nltk.sentiment.vader import SentimentIntensityAnalyzer

from shared.models import NewsCategory, NewsItem

from .base import BaseCollector

logger = logging.getLogger(__name__)

# First matching category wins; unmatched text falls through to MACRO.
_CATEGORY_KEYWORDS: dict[NewsCategory, list[str]] = {
    NewsCategory.FED: ["fed", "fomc", "rate", "federal reserve"],
    NewsCategory.POLICY: ["tariff", "trump", "regulation", "policy", "trade war"],
    NewsCategory.EARNINGS: ["earnings", "revenue", "profit", "eps", "guidance", "quarter"],
}


def _categorize(text: str) -> NewsCategory:
    """Assign a NewsCategory via keyword scan, defaulting to MACRO."""
    lower = text.lower()
    for category, keywords in _CATEGORY_KEYWORDS.items():
        if any(kw in lower for kw in keywords):
            return category
    return NewsCategory.MACRO


class FinnhubCollector(BaseCollector):
    """Collects general-market news from the Finnhub REST API."""

    name: str = "finnhub"
    poll_interval: int = 300

    _BASE_URL = "https://finnhub.io/api/v1/news"

    def __init__(self, api_key: str) -> None:
        self._api_key = api_key
        self._vader = SentimentIntensityAnalyzer()

    async def is_available(self) -> bool:
        """Available only when an API key was configured."""
        return bool(self._api_key)

    async def _fetch_news(self) -> list[dict]:
        """Fetch the general-news listing; raises on HTTP errors.

        The token is passed via ``params`` (not an f-string URL) so it is less
        likely to leak into exception reprs/logs, and the request carries the
        same 10s timeout as every other collector so a stalled call cannot
        hang the polling loop indefinitely.
        """
        params = {"category": "general", "token": self._api_key}
        async with aiohttp.ClientSession() as session:
            async with session.get(
                self._BASE_URL,
                params=params,
                timeout=aiohttp.ClientTimeout(total=10),
            ) as resp:
                resp.raise_for_status()
                return await resp.json()

    async def collect(self) -> list[NewsItem]:
        """Fetch, score, and categorize articles; [] on any fetch failure."""
        try:
            raw_items = await self._fetch_news()
        except Exception as exc:
            logger.error("Finnhub fetch failed: %s", exc)
            return []

        items: list[NewsItem] = []
        for article in raw_items:
            headline = article.get("headline", "")
            summary = article.get("summary", "")
            combined = f"{headline} {summary}"

            sentiment = self._vader.polarity_scores(combined)["compound"]

            ts = article.get("datetime", 0)
            published_at = datetime.fromtimestamp(ts, tz=UTC)

            # "related" is a comma-separated ticker list, e.g. "AAPL,MSFT".
            related = article.get("related", "")
            symbols = [t.strip() for t in related.split(",") if t.strip()] if related else []

            items.append(
                NewsItem(
                    source=self.name,
                    headline=headline,
                    summary=summary or None,
                    url=article.get("url") or None,
                    published_at=published_at,
                    symbols=symbols,
                    sentiment=sentiment,
                    category=_categorize(combined),
                    raw_data=article,
                )
            )

        return items
"""Reddit social sentiment collector using JSON API with VADER sentiment analysis."""

import logging
import re
from datetime import UTC, datetime

import aiohttp
from nltk.sentiment.vader import SentimentIntensityAnalyzer

from shared.models import NewsCategory, NewsItem

from .base import BaseCollector

logger = logging.getLogger(__name__)

# Subreddits polled for retail-investor chatter, and the minimum upvote
# score a post needs before it is treated as signal.
_SUBREDDITS = ["wallstreetbets", "stocks", "investing"]
_MIN_SCORE = 50

# Fixed whitelist of large-cap tickers; anything else is ignored.
_TICKER_PATTERN = re.compile(
    r"\b(AAPL|MSFT|GOOGL|GOOG|AMZN|TSLA|NVDA|META|BRK\.?[AB]|JPM|V|UNH|XOM|"
    r"JNJ|WMT|MA|PG|HD|CVX|MRK|LLY|ABBV|PFE|BAC|KO|AVGO|COST|MCD|TMO|"
    r"CSCO|ACN|ABT|DHR|TXN|NEE|NFLX|PM|UPS|RTX|HON|QCOM|AMGN|LOW|IBM|"
    r"INTC|AMD|PYPL|GS|MS|BLK|SPGI|CAT|DE|GE|MMM|BA|F|GM|DIS|CMCSA)\b"
)


class RedditCollector(BaseCollector):
    """Collects hot posts from finance subreddits and scores their sentiment."""

    name: str = "reddit"
    poll_interval: int = 900

    def __init__(self) -> None:
        self._vader = SentimentIntensityAnalyzer()

    async def is_available(self) -> bool:
        return True

    async def _fetch_subreddit(self, subreddit: str) -> list[dict]:
        """Return the raw 'hot' listing for one subreddit ([] on any failure)."""
        endpoint = f"https://www.reddit.com/r/{subreddit}/hot.json?limit=25"
        request_headers = {"User-Agent": "TradingPlatform/1.0 (research@example.com)"}
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    endpoint,
                    headers=request_headers,
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as resp:
                    if resp.status == 200:
                        listing = await resp.json()
                        return listing.get("data", {}).get("children", [])
        except Exception as exc:
            logger.error("Reddit fetch failed for r/%s: %s", subreddit, exc)
        return []

    async def collect(self) -> list[NewsItem]:
        """Gather qualifying posts across all subreddits, de-duplicated by title."""
        seen: set[str] = set()
        collected: list[NewsItem] = []

        for subreddit in _SUBREDDITS:
            try:
                children = await self._fetch_subreddit(subreddit)
            except Exception as exc:
                logger.error("Reddit collector error for r/%s: %s", subreddit, exc)
                continue

            for child in children:
                payload = child.get("data", {})
                title = payload.get("title", "").strip()

                # Skip empty, low-score, or already-seen posts.
                if not title or payload.get("score", 0) < _MIN_SCORE or title in seen:
                    continue
                seen.add(title)

                body_text = payload.get("selftext", "") or ""
                text = f"{title} {body_text}"

                collected.append(
                    NewsItem(
                        source=self.name,
                        headline=title,
                        summary=body_text or None,
                        url=payload.get("url") or None,
                        published_at=datetime.fromtimestamp(
                            payload.get("created_utc", 0), tz=UTC
                        ),
                        symbols=list(dict.fromkeys(_TICKER_PATTERN.findall(text))),
                        sentiment=self._vader.polarity_scores(text)["compound"],
                        category=NewsCategory.SOCIAL,
                        raw_data=payload,
                    )
                )

        return collected
"""RSS news collector using feedparser with VADER sentiment analysis."""

import asyncio
import logging
import re
from calendar import timegm
from datetime import UTC, datetime

import feedparser
from nltk.sentiment.vader import SentimentIntensityAnalyzer

from shared.models import NewsCategory, NewsItem

from .base import BaseCollector

logger = logging.getLogger(__name__)

_DEFAULT_FEEDS = [
    "https://finance.yahoo.com/news/rssindex",
    "https://news.google.com/rss/search?q=stock+market+finance&hl=en-US&gl=US&ceid=US:en",
    "https://feeds.marketwatch.com/marketwatch/topstories/",
]

# Fixed whitelist of large-cap tickers extracted from headlines/summaries.
_TICKER_PATTERN = re.compile(
    r"\b(AAPL|MSFT|GOOGL|GOOG|AMZN|TSLA|NVDA|META|BRK\.?[AB]|JPM|V|UNH|XOM|"
    r"JNJ|WMT|MA|PG|HD|CVX|MRK|LLY|ABBV|PFE|BAC|KO|AVGO|COST|MCD|TMO|"
    r"CSCO|ACN|ABT|DHR|TXN|NEE|NFLX|PM|UPS|RTX|HON|QCOM|AMGN|LOW|IBM|"
    r"INTC|AMD|PYPL|GS|MS|BLK|SPGI|CAT|DE|GE|MMM|BA|F|GM|DIS|CMCSA)\b"
)


class RSSCollector(BaseCollector):
    """Collects headlines from a configurable list of financial RSS feeds."""

    name: str = "rss"
    poll_interval: int = 600

    def __init__(self, feeds: list[str] | None = None) -> None:
        self._feeds = feeds if feeds is not None else _DEFAULT_FEEDS
        self._vader = SentimentIntensityAnalyzer()

    async def is_available(self) -> bool:
        return True

    async def _fetch_feeds(self) -> list[dict]:
        """Parse every configured feed off the event loop (feedparser blocks)."""
        # get_running_loop() is the supported call from inside a coroutine.
        loop = asyncio.get_running_loop()
        results = []
        for url in self._feeds:
            try:
                parsed = await loop.run_in_executor(None, feedparser.parse, url)
                results.append(parsed)
            except Exception as exc:
                logger.error("RSS fetch failed for %s: %s", url, exc)
        return results

    def _parse_published(self, entry: dict) -> datetime:
        """Convert feedparser's UTC struct_time to an aware datetime.

        feedparser normalizes ``published_parsed`` to UTC, so ``timegm`` is
        the correct conversion; ``time.mktime`` would interpret it as local
        time and skew every timestamp by the host's UTC offset (this matches
        how FedCollector parses its feed). Falls back to "now" when the
        entry has no usable timestamp.
        """
        parsed_time = entry.get("published_parsed")
        if parsed_time:
            try:
                ts = timegm(parsed_time)
                return datetime.fromtimestamp(ts, tz=UTC)
            except Exception:
                pass
        return datetime.now(UTC)

    async def collect(self) -> list[NewsItem]:
        """Fetch all feeds and build NewsItems, de-duplicated by title."""
        try:
            feeds = await self._fetch_feeds()
        except Exception as exc:
            logger.error("RSS collector error: %s", exc)
            return []

        seen_titles: set[str] = set()
        items: list[NewsItem] = []

        for feed in feeds:
            for entry in feed.get("entries", []):
                title = entry.get("title", "").strip()
                if not title or title in seen_titles:
                    continue
                seen_titles.add(title)

                summary = entry.get("summary", "") or ""
                combined = f"{title} {summary}"

                sentiment_scores = self._vader.polarity_scores(combined)
                sentiment = sentiment_scores["compound"]

                symbols = list(dict.fromkeys(_TICKER_PATTERN.findall(combined)))

                published_at = self._parse_published(entry)

                items.append(
                    NewsItem(
                        source=self.name,
                        headline=title,
                        summary=summary or None,
                        url=entry.get("link") or None,
                        published_at=published_at,
                        symbols=symbols,
                        sentiment=sentiment,
                        category=NewsCategory.MACRO,
                        raw_data=dict(entry),
                    )
                )

        return items
return datetime.now(UTC) + + async def collect(self) -> list[NewsItem]: + try: + feeds = await self._fetch_feeds() + except Exception as exc: + logger.error("RSS collector error: %s", exc) + return [] + + seen_titles: set[str] = set() + items: list[NewsItem] = [] + + for feed in feeds: + for entry in feed.get("entries", []): + title = entry.get("title", "").strip() + if not title or title in seen_titles: + continue + seen_titles.add(title) + + summary = entry.get("summary", "") or "" + combined = f"{title} {summary}" + + sentiment_scores = self._vader.polarity_scores(combined) + sentiment = sentiment_scores["compound"] + + symbols = list(dict.fromkeys(_TICKER_PATTERN.findall(combined))) + + published_at = self._parse_published(entry) + + items.append( + NewsItem( + source=self.name, + headline=title, + summary=summary or None, + url=entry.get("link") or None, + published_at=published_at, + symbols=symbols, + sentiment=sentiment, + category=NewsCategory.MACRO, + raw_data=dict(entry), + ) + ) + + return items diff --git a/services/news-collector/src/news_collector/collectors/sec_edgar.py b/services/news-collector/src/news_collector/collectors/sec_edgar.py new file mode 100644 index 0000000..d88518f --- /dev/null +++ b/services/news-collector/src/news_collector/collectors/sec_edgar.py @@ -0,0 +1,98 @@ +"""SEC EDGAR filing collector (free, no API key required).""" + +import logging +from datetime import UTC, datetime + +import aiohttp +from nltk.sentiment.vader import SentimentIntensityAnalyzer + +from news_collector.collectors.base import BaseCollector +from shared.models import NewsCategory, NewsItem + +logger = logging.getLogger(__name__) + +TRACKED_CIKS = { + "0000320193": "AAPL", + "0000789019": "MSFT", + "0001652044": "GOOGL", + "0001018724": "AMZN", + "0001318605": "TSLA", + "0001045810": "NVDA", + "0001326801": "META", + "0000019617": "JPM", + "0000078003": "PFE", + "0000021344": "KO", +} + +SEC_USER_AGENT = "TradingPlatform research@example.com" + + +class 
"""SEC EDGAR filing collector (free, no API key required)."""

import logging
from datetime import UTC, datetime

import aiohttp
from nltk.sentiment.vader import SentimentIntensityAnalyzer

from news_collector.collectors.base import BaseCollector
from shared.models import NewsCategory, NewsItem

logger = logging.getLogger(__name__)

# CIK -> ticker for the companies whose filings we track.
TRACKED_CIKS = {
    "0000320193": "AAPL",
    "0000789019": "MSFT",
    "0001652044": "GOOGL",
    "0001018724": "AMZN",
    "0001318605": "TSLA",
    "0001045810": "NVDA",
    "0001326801": "META",
    "0000019617": "JPM",
    "0000078003": "PFE",
    "0000021344": "KO",
}

# SEC requires a declared User-Agent with contact info on all requests.
SEC_USER_AGENT = "TradingPlatform research@example.com"


class SecEdgarCollector(BaseCollector):
    """Collects same-day 8-K filings for a fixed set of large-cap companies."""

    name = "sec_edgar"
    poll_interval = 1800  # 30 minutes

    def __init__(self) -> None:
        self._vader = SentimentIntensityAnalyzer()

    async def is_available(self) -> bool:
        return True

    async def _fetch_recent_filings(self) -> list[dict]:
        """Fetch the submissions JSON for every tracked CIK (best-effort)."""
        results = []
        headers = {"User-Agent": SEC_USER_AGENT}
        async with aiohttp.ClientSession() as session:
            for cik, ticker in TRACKED_CIKS.items():
                try:
                    url = f"https://data.sec.gov/submissions/CIK{cik}.json"
                    async with session.get(
                        url, headers=headers, timeout=aiohttp.ClientTimeout(total=10)
                    ) as resp:
                        if resp.status == 200:
                            data = await resp.json()
                            data["tickers"] = [{"ticker": ticker}]
                            results.append(data)
                except Exception as exc:
                    # stdlib logging takes printf-style args, not structlog
                    # kwargs — the previous kwargs form raised TypeError on
                    # this path.
                    logger.warning("sec_fetch_failed cik=%s error=%s", cik, exc)
        return results

    async def collect(self) -> list[NewsItem]:
        """Build NewsItems for every 8-K filed today by a tracked company."""
        filings_data = await self._fetch_recent_filings()
        items = []
        today = datetime.now(UTC).strftime("%Y-%m-%d")

        for company_data in filings_data:
            tickers = [t["ticker"] for t in company_data.get("tickers", [])]
            company_name = company_data.get("name", "Unknown")
            recent = company_data.get("filings", {}).get("recent", {})

            # The submissions payload is column-oriented: parallel lists
            # indexed by filing position.
            forms = recent.get("form", [])
            dates = recent.get("filingDate", [])
            descriptions = recent.get("primaryDocDescription", [])
            accessions = recent.get("accessionNumber", [])

            for i, form in enumerate(forms):
                if form != "8-K":
                    continue
                filing_date = dates[i] if i < len(dates) else ""
                if filing_date != today:
                    continue

                desc = descriptions[i] if i < len(descriptions) else "8-K Filing"
                accession = accessions[i] if i < len(accessions) else ""
                headline = f"{company_name} ({', '.join(tickers)}): {form} - {desc}"

                items.append(
                    NewsItem(
                        source=self.name,
                        headline=headline,
                        summary=desc,
                        # NOTE(review): browse-edgar does not document an
                        # "accession" query parameter — verify this URL
                        # actually resolves to the filing.
                        url=f"https://www.sec.gov/cgi-bin/browse-edgar?action=getcompany&accession={accession}",
                        published_at=datetime.strptime(filing_date, "%Y-%m-%d").replace(tzinfo=UTC),
                        symbols=tickers,
                        sentiment=self._vader.polarity_scores(headline)["compound"],
                        category=NewsCategory.FILING,
                        raw_data={"form": form, "accession": accession},
                    )
                )

        return items
"""Truth Social collector using Mastodon-compatible API with VADER sentiment analysis."""

import logging
import re
from datetime import UTC, datetime

import aiohttp
from nltk.sentiment.vader import SentimentIntensityAnalyzer

from shared.models import NewsCategory, NewsItem

from .base import BaseCollector

logger = logging.getLogger(__name__)

_TRUMP_ACCOUNT_ID = "107780257626128497"
_API_URL = f"https://truthsocial.com/api/v1/accounts/{_TRUMP_ACCOUNT_ID}/statuses"

_HTML_TAG_PATTERN = re.compile(r"<[^>]+>")


def _strip_html(text: str) -> str:
    """Drop HTML tags from a status body and trim surrounding whitespace."""
    return _HTML_TAG_PATTERN.sub("", text).strip()


class TruthSocialCollector(BaseCollector):
    """Collects recent statuses from one account and scores their sentiment."""

    name: str = "truth_social"
    poll_interval: int = 900

    def __init__(self) -> None:
        self._vader = SentimentIntensityAnalyzer()

    async def is_available(self) -> bool:
        return True

    async def _fetch_posts(self) -> list[dict]:
        """Return the raw status list ([] on non-200 or any failure)."""
        request_headers = {"User-Agent": "TradingPlatform/1.0 (research@example.com)"}
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    _API_URL,
                    headers=request_headers,
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as resp:
                    if resp.status == 200:
                        return await resp.json()
        except Exception as exc:
            logger.error("Truth Social fetch failed: %s", exc)
        return []

    def _parse_created(self, created_at_str: str) -> datetime:
        """Parse the ISO-8601 creation time, defaulting to "now" on failure."""
        try:
            return datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
        except Exception:
            return datetime.now(UTC)

    async def collect(self) -> list[NewsItem]:
        """Build NewsItems from non-empty statuses (headline = first 200 chars)."""
        try:
            posts = await self._fetch_posts()
        except Exception as exc:
            logger.error("Truth Social collector error: %s", exc)
            return []

        collected: list[NewsItem] = []
        for post in posts:
            text = _strip_html(post.get("content", "") or "")
            if not text:
                continue

            collected.append(
                NewsItem(
                    source=self.name,
                    headline=text[:200],
                    # Full text only kept as summary when it was truncated.
                    summary=text if len(text) > 200 else None,
                    url=post.get("url") or None,
                    published_at=self._parse_created(post.get("created_at", "")),
                    symbols=[],
                    sentiment=self._vader.polarity_scores(text)["compound"],
                    category=NewsCategory.POLICY,
                    raw_data=post,
                )
            )

        return collected
[] + + for post in posts: + raw_content = post.get("content", "") or "" + content = _strip_html(raw_content) + if not content: + continue + + sentiment = self._vader.polarity_scores(content)["compound"] + + created_at_str = post.get("created_at", "") + try: + published_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) + except Exception: + published_at = datetime.now(UTC) + + items.append( + NewsItem( + source=self.name, + headline=content[:200], + summary=content if len(content) > 200 else None, + url=post.get("url") or None, + published_at=published_at, + symbols=[], + sentiment=sentiment, + category=NewsCategory.POLICY, + raw_data=post, + ) + ) + + return items diff --git a/services/news-collector/src/news_collector/config.py b/services/news-collector/src/news_collector/config.py new file mode 100644 index 0000000..6e78eba --- /dev/null +++ b/services/news-collector/src/news_collector/config.py @@ -0,0 +1,7 @@ +"""News Collector configuration.""" + +from shared.config import Settings + + +class NewsCollectorConfig(Settings): + health_port: int = 8084 diff --git a/services/news-collector/src/news_collector/main.py b/services/news-collector/src/news_collector/main.py new file mode 100644 index 0000000..c39fa67 --- /dev/null +++ b/services/news-collector/src/news_collector/main.py @@ -0,0 +1,204 @@ +"""News Collector Service — fetches news from multiple sources and aggregates sentiment.""" + +import asyncio +from datetime import UTC, datetime + +import aiohttp + +from news_collector.collectors.fear_greed import FearGreedCollector +from news_collector.collectors.fed import FedCollector +from news_collector.collectors.finnhub import FinnhubCollector +from news_collector.collectors.reddit import RedditCollector +from news_collector.collectors.rss import RSSCollector +from news_collector.collectors.sec_edgar import SecEdgarCollector +from news_collector.collectors.truth_social import TruthSocialCollector +from news_collector.config import 
from news_collector.config import NewsCollectorConfig
from shared.broker import RedisBroker
from shared.db import Database
from shared.events import NewsEvent
from shared.healthcheck import HealthCheckServer
from shared.logging import setup_logging
from shared.metrics import ServiceMetrics
from shared.models import NewsItem
from shared.notifier import TelegramNotifier
from shared.sentiment import SentimentAggregator
from shared.sentiment_models import MarketSentiment
from shared.shutdown import GracefulShutdown


async def run_collector_once(collector, db: Database, broker: RedisBroker) -> int:
    """Run a single collector, store results in DB, publish to Redis.

    Returns the number of items collected.
    """
    fetched: list[NewsItem] = await collector.collect()
    published = 0
    for item in fetched:
        # Persist first, then fan out on the per-category stream.
        await db.insert_news_item(item)
        await broker.publish(f"news.{item.category.value}", NewsEvent(data=item).to_dict())
        published += 1
    return published


async def run_collector_loop(collector, db: Database, broker: RedisBroker, log) -> None:
    """Run a collector repeatedly on its configured poll_interval.

    Network and parse failures are logged and the loop keeps going; anything
    else propagates and tears the task down.
    """
    while True:
        try:
            stored = await run_collector_once(collector, db, broker)
            log.info(
                "collector_ran",
                collector=collector.name,
                count=stored,
            )
        except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
            log.warning(
                "collector_network_error",
                collector=collector.name,
                error=str(exc),
            )
        except (ValueError, KeyError, TypeError) as exc:
            log.warning(
                "collector_parse_error",
                collector=collector.name,
                error=str(exc),
            )
        await asyncio.sleep(collector.poll_interval)


async def run_fear_greed_loop(collector: FearGreedCollector, db: Database, log) -> None:
    """Fetch Fear & Greed index on its interval and update MarketSentiment in DB."""
    while True:
        try:
            reading = await collector.collect()
            if reading is not None:
                snapshot = MarketSentiment(
                    fear_greed=reading.fear_greed,
                    fear_greed_label=reading.fear_greed_label,
                    vix=None,
                    # Stance is refreshed elsewhere; this loop only owns F&G.
                    fed_stance="neutral",
                    market_regime=_determine_regime(reading.fear_greed, None),
                    updated_at=datetime.now(UTC),
                )
                await db.upsert_market_sentiment(snapshot)
                log.info(
                    "fear_greed_updated",
                    value=reading.fear_greed,
                    label=reading.fear_greed_label,
                )
        except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc:
            log.warning("fear_greed_network_error", error=str(exc))
        except (ValueError, KeyError, TypeError) as exc:
            log.warning("fear_greed_parse_error", error=str(exc))
        await asyncio.sleep(collector.poll_interval)
import asyncio
from datetime import UTC, datetime

from news_collector.config import NewsCollectorConfig
from shared.broker import RedisBroker
from shared.db import Database
from shared.healthcheck import HealthCheckServer
from shared.logging import setup_logging
from shared.metrics import ServiceMetrics
from shared.notifier import TelegramNotifier
from shared.sentiment import SentimentAggregator
from shared.shutdown import GracefulShutdown


async def run_aggregator_loop(db: Database, interval: int, log) -> None:
    """Run SentimentAggregator every interval seconds and persist scores."""
    aggregator = SentimentAggregator()
    while True:
        # Sleep first: aggregation needs some collected news to be useful.
        await asyncio.sleep(interval)
        try:
            now = datetime.now(UTC)
            news_items = await db.get_recent_news(hours=24)
            scores = aggregator.aggregate(news_items, now)
            for score in scores.values():
                await db.upsert_symbol_score(score)
            log.info("aggregation_complete", symbols=len(scores))
        except (ConnectionError, TimeoutError) as exc:
            log.warning("aggregator_network_error", error=str(exc))
        except (ValueError, KeyError, TypeError) as exc:
            log.warning("aggregator_parse_error", error=str(exc))


def _determine_regime(fear_greed: int, vix: float | None) -> str:
    """Classify market regime from fear/greed index and optional VIX."""
    aggregator = SentimentAggregator()
    return aggregator.determine_regime(fear_greed, vix)


async def run() -> None:
    """Wire up dependencies, spawn all collector loops, and run until shutdown."""
    config = NewsCollectorConfig()
    log = setup_logging("news-collector", config.log_level, config.log_format)
    metrics = ServiceMetrics("news_collector")

    notifier = TelegramNotifier(
        bot_token=config.telegram_bot_token.get_secret_value(),
        chat_id=config.telegram_chat_id,
    )

    db = Database(config.database_url.get_secret_value())
    await db.connect()

    broker = RedisBroker(config.redis_url.get_secret_value())

    health = HealthCheckServer(
        "news-collector",
        port=config.health_port,
        auth_token=config.metrics_auth_token,
    )
    await health.start()
    metrics.service_up.labels(service="news-collector").set(1)

    # Build collectors
    finnhub = FinnhubCollector(api_key=config.finnhub_api_key.get_secret_value())
    rss = RSSCollector()
    sec = SecEdgarCollector()
    truth = TruthSocialCollector()
    reddit = RedditCollector()
    fear_greed = FearGreedCollector()
    fed = FedCollector()

    news_collectors = [finnhub, rss, sec, truth, reddit, fed]

    shutdown = GracefulShutdown()
    shutdown.install_handlers()

    log.info(
        "starting",
        collectors=[c.name for c in news_collectors],
        poll_interval=config.news_poll_interval,
        aggregate_interval=config.sentiment_aggregate_interval,
    )

    # Pre-bind so the `finally` block cannot hit a NameError if task
    # creation itself fails.
    tasks: list[asyncio.Task] = []
    try:
        tasks = [
            asyncio.create_task(
                run_collector_loop(collector, db, broker, log),
                name=f"collector-{collector.name}",
            )
            for collector in news_collectors
        ]
        tasks.append(
            asyncio.create_task(
                run_fear_greed_loop(fear_greed, db, log),
                name="fear-greed-loop",
            )
        )
        tasks.append(
            asyncio.create_task(
                run_aggregator_loop(db, config.sentiment_aggregate_interval, log),
                name="aggregator-loop",
            )
        )
        await shutdown.wait()
    except Exception as exc:
        log.error("fatal_error", error=str(exc), exc_info=True)
        await notifier.send_error(str(exc), "news-collector")
        raise
    finally:
        metrics.service_up.labels(service="news-collector").set(0)
        for task in tasks:
            task.cancel()
        # Await the cancellations so loops actually stop (and finish any
        # in-flight work) before the broker/db they use are closed.
        await asyncio.gather(*tasks, return_exceptions=True)
        await notifier.close()
        await broker.close()
        await db.close()


def main() -> None:
    asyncio.run(run())


if __name__ == "__main__":
    main()
"""Tests for CNN Fear & Greed Index collector."""
# NOTE(review): async tests here carry no @pytest.mark.asyncio — relies on
# pytest-asyncio auto mode; confirm project config.

from unittest.mock import AsyncMock, patch

import pytest
from news_collector.collectors.fear_greed import FearGreedCollector


@pytest.fixture
def collector():
    return FearGreedCollector()


def test_collector_name(collector):
    assert collector.name == "fear_greed"
    assert collector.poll_interval == 3600


async def test_is_available(collector):
    assert await collector.is_available() is True


async def test_collect_parses_api_response(collector):
    payload = {
        "fear_and_greed": {
            "score": 45.0,
            "rating": "Fear",
            "timestamp": "2026-04-02T12:00:00+00:00",
        }
    }
    with patch.object(collector, "_fetch_index", new_callable=AsyncMock, return_value=payload):
        parsed = await collector.collect()
    assert parsed.fear_greed == 45
    assert parsed.fear_greed_label == "Fear"


async def test_collect_returns_none_on_failure(collector):
    with patch.object(collector, "_fetch_index", new_callable=AsyncMock, return_value=None):
        assert await collector.collect() is None


def test_classify_label():
    c = FearGreedCollector()
    expected = {10: "Extreme Fear", 30: "Fear", 50: "Neutral", 70: "Greed", 85: "Extreme Greed"}
    for score, label in expected.items():
        assert c._classify(score) == label
test_is_available(collector): + assert await collector.is_available() is True + + +async def test_collect_parses_rss(collector): + mock_entries = [ + { + "title": "Federal Reserve issues FOMC statement", + "link": "https://www.federalreserve.gov/newsevents/pressreleases/monetary20260402a.htm", + "published_parsed": (2026, 4, 2, 14, 0, 0, 0, 0, 0), + "summary": "The Federal Open Market Committee decided to maintain the target range...", + }, + ] + with patch.object( + collector, "_fetch_fed_rss", new_callable=AsyncMock, return_value=mock_entries + ): + items = await collector.collect() + assert len(items) == 1 + assert items[0].source == "fed" + assert items[0].category.value == "fed" diff --git a/services/news-collector/tests/test_finnhub.py b/services/news-collector/tests/test_finnhub.py new file mode 100644 index 0000000..3af65b8 --- /dev/null +++ b/services/news-collector/tests/test_finnhub.py @@ -0,0 +1,67 @@ +"""Tests for Finnhub news collector.""" + +from unittest.mock import AsyncMock, patch + +import pytest +from news_collector.collectors.finnhub import FinnhubCollector + + +@pytest.fixture +def collector(): + return FinnhubCollector(api_key="test_key") + + +def test_collector_name(collector): + assert collector.name == "finnhub" + assert collector.poll_interval == 300 + + +async def test_is_available_with_key(collector): + assert await collector.is_available() is True + + +async def test_is_available_without_key(): + c = FinnhubCollector(api_key="") + assert await c.is_available() is False + + +async def test_collect_parses_response(collector): + mock_response = [ + { + "category": "top news", + "datetime": 1711929600, + "headline": "AAPL beats earnings", + "id": 12345, + "related": "AAPL", + "source": "MarketWatch", + "summary": "Apple reported better than expected...", + "url": "https://example.com/article", + }, + { + "category": "top news", + "datetime": 1711929000, + "headline": "Fed holds rates steady", + "id": 12346, + "related": "", + "source": 
"""Tests for news collector scheduler."""

from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock

from news_collector.main import run_collector_once

from shared.models import NewsCategory, NewsItem


async def test_run_collector_once_stores_and_publishes():
    item = NewsItem(
        source="test",
        headline="Test news",
        published_at=datetime(2026, 4, 2, tzinfo=UTC),
        sentiment=0.5,
        category=NewsCategory.MACRO,
    )
    collector = MagicMock()
    collector.name = "test"
    collector.collect = AsyncMock(return_value=[item])
    db = MagicMock()
    db.insert_news_item = AsyncMock()
    broker = MagicMock()
    broker.publish = AsyncMock()

    assert await run_collector_once(collector, db, broker) == 1
    db.insert_news_item.assert_called_once_with(item)
    broker.publish.assert_called_once()


async def test_run_collector_once_handles_empty():
    collector = MagicMock()
    collector.name = "test"
    collector.collect = AsyncMock(return_value=[])

    assert await run_collector_once(collector, MagicMock(), MagicMock()) == 0
"""Tests for Reddit collector."""

from unittest.mock import AsyncMock, patch

import pytest
from news_collector.collectors.reddit import RedditCollector


@pytest.fixture
def collector():
    return RedditCollector()


def test_collector_name(collector):
    assert collector.name == "reddit"
    assert collector.poll_interval == 900


async def test_is_available(collector):
    assert await collector.is_available() is True


async def test_collect_parses_posts(collector):
    hot_listing = [
        {
            "data": {
                "title": "NVDA to the moon! AI demand is insane",
                "selftext": "Just loaded up on NVDA calls",
                "url": "https://reddit.com/r/wallstreetbets/123",
                "created_utc": 1711929600,
                "score": 500,
                "num_comments": 200,
                "subreddit": "wallstreetbets",
            }
        },
    ]
    with patch.object(
        collector, "_fetch_subreddit", new_callable=AsyncMock, return_value=hot_listing
    ):
        items = await collector.collect()
    assert len(items) >= 1
    assert items[0].source == "reddit"
    assert items[0].category.value == "social"


async def test_collect_filters_low_score(collector):
    hot_listing = [
        {
            "data": {
                "title": "Random question",
                "selftext": "",
                "url": "https://reddit.com/456",
                "created_utc": 1711929600,
                "score": 3,
                "num_comments": 1,
                "subreddit": "stocks",
            }
        },
    ]
    with patch.object(
        collector, "_fetch_subreddit", new_callable=AsyncMock, return_value=hot_listing
    ):
        assert await collector.collect() == []
b/services/news-collector/tests/test_rss.py new file mode 100644 index 0000000..7242c75 --- /dev/null +++ b/services/news-collector/tests/test_rss.py @@ -0,0 +1,47 @@ +"""Tests for RSS news collector.""" + +from unittest.mock import AsyncMock, patch + +import pytest +from news_collector.collectors.rss import RSSCollector + + +@pytest.fixture +def collector(): + return RSSCollector() + + +def test_collector_name(collector): + assert collector.name == "rss" + assert collector.poll_interval == 600 + + +async def test_is_available(collector): + assert await collector.is_available() is True + + +async def test_collect_parses_feed(collector): + mock_feed = { + "entries": [ + { + "title": "NVDA surges on AI demand", + "link": "https://example.com/nvda", + "published_parsed": (2026, 4, 2, 12, 0, 0, 0, 0, 0), + "summary": "Nvidia stock jumped 5%...", + }, + { + "title": "Markets rally on jobs data", + "link": "https://example.com/market", + "published_parsed": (2026, 4, 2, 11, 0, 0, 0, 0, 0), + "summary": "The S&P 500 rose...", + }, + ], + } + + with patch.object(collector, "_fetch_feeds", new_callable=AsyncMock, return_value=[mock_feed]): + items = await collector.collect() + + assert len(items) == 2 + assert items[0].source == "rss" + assert items[0].headline == "NVDA surges on AI demand" + assert isinstance(items[0].sentiment, float) diff --git a/services/news-collector/tests/test_sec_edgar.py b/services/news-collector/tests/test_sec_edgar.py new file mode 100644 index 0000000..b0faf18 --- /dev/null +++ b/services/news-collector/tests/test_sec_edgar.py @@ -0,0 +1,58 @@ +"""Tests for SEC EDGAR filing collector.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from news_collector.collectors.sec_edgar import SecEdgarCollector + + +@pytest.fixture +def collector(): + return SecEdgarCollector() + + +def test_collector_name(collector): + assert collector.name == "sec_edgar" + assert collector.poll_interval == 
1800 + + +async def test_is_available(collector): + assert await collector.is_available() is True + + +async def test_collect_parses_filings(collector): + mock_response = { + "filings": { + "recent": { + "accessionNumber": ["0001234-26-000001"], + "filingDate": ["2026-04-02"], + "primaryDocument": ["filing.htm"], + "form": ["8-K"], + "primaryDocDescription": ["Current Report"], + } + }, + "tickers": [{"ticker": "AAPL"}], + "name": "Apple Inc", + } + + mock_datetime = MagicMock(spec=datetime) + mock_datetime.now.return_value = datetime(2026, 4, 2, tzinfo=UTC) + mock_datetime.strptime = datetime.strptime + + with patch.object( + collector, "_fetch_recent_filings", new_callable=AsyncMock, return_value=[mock_response] + ): + with patch("news_collector.collectors.sec_edgar.datetime", mock_datetime): + items = await collector.collect() + + assert len(items) == 1 + assert items[0].source == "sec_edgar" + assert items[0].category.value == "filing" + assert "AAPL" in items[0].symbols + + +async def test_collect_handles_empty(collector): + with patch.object(collector, "_fetch_recent_filings", new_callable=AsyncMock, return_value=[]): + items = await collector.collect() + assert items == [] diff --git a/services/news-collector/tests/test_truth_social.py b/services/news-collector/tests/test_truth_social.py new file mode 100644 index 0000000..52f1e46 --- /dev/null +++ b/services/news-collector/tests/test_truth_social.py @@ -0,0 +1,42 @@ +"""Tests for Truth Social collector.""" + +from unittest.mock import AsyncMock, patch + +import pytest +from news_collector.collectors.truth_social import TruthSocialCollector + + +@pytest.fixture +def collector(): + return TruthSocialCollector() + + +def test_collector_name(collector): + assert collector.name == "truth_social" + assert collector.poll_interval == 900 + + +async def test_is_available(collector): + assert await collector.is_available() is True + + +async def test_collect_parses_posts(collector): + mock_posts = [ + { + "content": 
"<p>We are imposing 25% tariffs on all steel imports!</p>", + "created_at": "2026-04-02T12:00:00.000Z", + "url": "https://truthsocial.com/@realDonaldTrump/12345", + "id": "12345", + }, + ] + with patch.object(collector, "_fetch_posts", new_callable=AsyncMock, return_value=mock_posts): + items = await collector.collect() + assert len(items) == 1 + assert items[0].source == "truth_social" + assert items[0].category.value == "policy" + + +async def test_collect_handles_empty(collector): + with patch.object(collector, "_fetch_posts", new_callable=AsyncMock, return_value=[]): + items = await collector.collect() + assert items == [] diff --git a/services/order-executor/Dockerfile b/services/order-executor/Dockerfile index bc8b21c..376afec 100644 --- a/services/order-executor/Dockerfile +++ b/services/order-executor/Dockerfile @@ -1,8 +1,15 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/order-executor/ services/order-executor/ RUN pip install --no-cache-dir ./services/order-executor + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin ENV PYTHONPATH=/app +USER appuser CMD ["python", "-m", "order_executor.main"] diff --git a/services/order-executor/src/order_executor/executor.py b/services/order-executor/src/order_executor/executor.py index a71e762..fd502cd 100644 --- a/services/order-executor/src/order_executor/executor.py +++ b/services/order-executor/src/order_executor/executor.py @@ -1,18 +1,18 @@ """Order execution logic.""" -import structlog -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal -from typing import Any, Optional +from typing import Any + +import structlog +from order_executor.risk_manager import RiskManager from shared.broker 
import RedisBroker from shared.db import Database from shared.events import OrderEvent from shared.models import Order, OrderStatus, OrderType, Signal from shared.notifier import TelegramNotifier -from order_executor.risk_manager import RiskManager - logger = structlog.get_logger() @@ -35,7 +35,7 @@ class OrderExecutor: self.notifier = notifier self.dry_run = dry_run - async def execute(self, signal: Signal) -> Optional[Order]: + async def execute(self, signal: Signal) -> Order | None: """Run risk checks and place an order for the given signal.""" # Fetch buying power from Alpaca balance = await self.exchange.get_buying_power() @@ -71,7 +71,7 @@ class OrderExecutor: if self.dry_run: order.status = OrderStatus.FILLED - order.filled_at = datetime.now(timezone.utc) + order.filled_at = datetime.now(UTC) logger.info( "order_filled_dry_run", side=str(order.side), @@ -87,7 +87,7 @@ class OrderExecutor: type="market", ) order.status = OrderStatus.FILLED - order.filled_at = datetime.now(timezone.utc) + order.filled_at = datetime.now(UTC) logger.info( "order_filled", side=str(order.side), diff --git a/services/order-executor/src/order_executor/main.py b/services/order-executor/src/order_executor/main.py index 51ab286..99f88e1 100644 --- a/services/order-executor/src/order_executor/main.py +++ b/services/order-executor/src/order_executor/main.py @@ -3,6 +3,11 @@ import asyncio from decimal import Decimal +import aiohttp + +from order_executor.config import ExecutorConfig +from order_executor.executor import OrderExecutor +from order_executor.risk_manager import RiskManager from shared.alpaca import AlpacaClient from shared.broker import RedisBroker from shared.db import Database @@ -11,10 +16,7 @@ from shared.healthcheck import HealthCheckServer from shared.logging import setup_logging from shared.metrics import ServiceMetrics from shared.notifier import TelegramNotifier - -from order_executor.config import ExecutorConfig -from order_executor.executor import OrderExecutor 
-from order_executor.risk_manager import RiskManager +from shared.shutdown import GracefulShutdown # Health check port: base + 2 HEALTH_PORT_OFFSET = 2 @@ -26,18 +28,18 @@ async def run() -> None: metrics = ServiceMetrics("order_executor") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id, ) - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) alpaca = AlpacaClient( - api_key=config.alpaca_api_key, - api_secret=config.alpaca_api_secret, + api_key=config.alpaca_api_key.get_secret_value(), + api_secret=config.alpaca_api_secret.get_secret_value(), paper=config.alpaca_paper, ) @@ -83,6 +85,9 @@ async def run() -> None: await broker.ensure_group(stream, GROUP) + shutdown = GracefulShutdown() + shutdown.install_handlers() + log.info("started", stream=stream, dry_run=config.dry_run) try: @@ -94,10 +99,15 @@ async def run() -> None: if event.type == EventType.SIGNAL: await executor.execute(event.data) await broker.ack(stream, GROUP, msg_id) + except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc: + log.warning("pending_network_error", error=str(exc), msg_id=msg_id) + except (ValueError, KeyError, TypeError) as exc: + log.warning("pending_parse_error", error=str(exc), msg_id=msg_id) + await broker.ack(stream, GROUP, msg_id) except Exception as exc: - log.error("pending_failed", error=str(exc), msg_id=msg_id) + log.error("pending_failed", error=str(exc), msg_id=msg_id, exc_info=True) - while True: + while not shutdown.is_shutting_down: messages = await broker.read_group(stream, GROUP, CONSUMER, count=10, block=5000) for msg_id, msg in messages: try: @@ -110,8 +120,19 @@ async def run() -> None: service="order-executor", event_type="signal" ).inc() await broker.ack(stream, GROUP, msg_id) + except 
(aiohttp.ClientError, ConnectionError, TimeoutError) as exc: + log.warning("process_network_error", error=str(exc)) + metrics.errors_total.labels( + service="order-executor", error_type="network" + ).inc() + except (ValueError, KeyError, TypeError) as exc: + log.warning("process_parse_error", error=str(exc)) + await broker.ack(stream, GROUP, msg_id) + metrics.errors_total.labels( + service="order-executor", error_type="validation" + ).inc() except Exception as exc: - log.error("process_failed", error=str(exc)) + log.error("process_failed", error=str(exc), exc_info=True) metrics.errors_total.labels( service="order-executor", error_type="processing" ).inc() diff --git a/services/order-executor/src/order_executor/risk_manager.py b/services/order-executor/src/order_executor/risk_manager.py index 5a05746..811a862 100644 --- a/services/order-executor/src/order_executor/risk_manager.py +++ b/services/order-executor/src/order_executor/risk_manager.py @@ -1,12 +1,12 @@ """Risk management for order execution.""" +import math +from collections import deque from dataclasses import dataclass -from datetime import datetime, timezone, timedelta +from datetime import UTC, datetime, timedelta from decimal import Decimal -from collections import deque -import math -from shared.models import Signal, OrderSide, Position +from shared.models import OrderSide, Position, Signal @dataclass @@ -123,15 +123,13 @@ class RiskManager: else: self._consecutive_losses += 1 if self._consecutive_losses >= self._max_consecutive_losses: - self._paused_until = datetime.now(timezone.utc) + timedelta( - minutes=self._loss_pause_minutes - ) + self._paused_until = datetime.now(UTC) + timedelta(minutes=self._loss_pause_minutes) def is_paused(self) -> bool: """Check if trading is paused due to consecutive losses.""" if self._paused_until is None: return False - if datetime.now(timezone.utc) >= self._paused_until: + if datetime.now(UTC) >= self._paused_until: self._paused_until = None self._consecutive_losses 
= 0 return False @@ -233,9 +231,9 @@ class RiskManager: mean_a = sum(returns_a) / len(returns_a) mean_b = sum(returns_b) / len(returns_b) - cov = sum((a - mean_a) * (b - mean_b) for a, b in zip(returns_a, returns_b)) / len( - returns_a - ) + cov = sum( + (a - mean_a) * (b - mean_b) for a, b in zip(returns_a, returns_b, strict=True) + ) / len(returns_a) std_a = math.sqrt(sum((a - mean_a) ** 2 for a in returns_a) / len(returns_a)) std_b = math.sqrt(sum((b - mean_b) ** 2 for b in returns_b) / len(returns_b)) @@ -280,7 +278,11 @@ class RiskManager: min_len = min(len(r) for r in all_returns) portfolio_returns = [] for i in range(min_len): - pr = sum(w * r[-(min_len - i)] for w, r in zip(weights, all_returns) if len(r) > i) + pr = sum( + w * r[-(min_len - i)] + for w, r in zip(weights, all_returns, strict=False) + if len(r) > i + ) portfolio_returns.append(pr) if not portfolio_returns: diff --git a/services/order-executor/tests/test_executor.py b/services/order-executor/tests/test_executor.py index dd823d7..cda6b72 100644 --- a/services/order-executor/tests/test_executor.py +++ b/services/order-executor/tests/test_executor.py @@ -4,11 +4,11 @@ from decimal import Decimal from unittest.mock import AsyncMock, MagicMock import pytest - -from shared.models import OrderSide, OrderStatus, Signal from order_executor.executor import OrderExecutor from order_executor.risk_manager import RiskCheckResult, RiskManager +from shared.models import OrderSide, OrderStatus, Signal + def make_signal(side: OrderSide = OrderSide.BUY, price: str = "100", quantity: str = "1") -> Signal: return Signal( diff --git a/services/order-executor/tests/test_risk_manager.py b/services/order-executor/tests/test_risk_manager.py index 00a9ab4..66e769c 100644 --- a/services/order-executor/tests/test_risk_manager.py +++ b/services/order-executor/tests/test_risk_manager.py @@ -2,12 +2,12 @@ from decimal import Decimal +from order_executor.risk_manager import RiskManager from shared.models import OrderSide, 
Position, Signal -from order_executor.risk_manager import RiskManager -def make_signal(side: OrderSide, price: str, quantity: str, symbol: str = "BTC/USDT") -> Signal: +def make_signal(side: OrderSide, price: str, quantity: str, symbol: str = "AAPL") -> Signal: return Signal( strategy="test", symbol=symbol, @@ -93,7 +93,7 @@ def test_risk_check_rejects_insufficient_balance(): def test_trailing_stop_set_and_trigger(): """Trailing stop should trigger when price drops below stop level.""" rm = make_risk_manager(trailing_stop_pct="5") - rm.set_trailing_stop("BTC/USDT", Decimal("100")) + rm.set_trailing_stop("AAPL", Decimal("100")) signal = make_signal(side=OrderSide.BUY, price="94", quantity="0.01") result = rm.check(signal, balance=Decimal("10000"), positions={}, daily_pnl=Decimal("0")) @@ -104,10 +104,10 @@ def test_trailing_stop_set_and_trigger(): def test_trailing_stop_updates_highest_price(): """Trailing stop should track the highest price seen.""" rm = make_risk_manager(trailing_stop_pct="5") - rm.set_trailing_stop("BTC/USDT", Decimal("100")) + rm.set_trailing_stop("AAPL", Decimal("100")) # Price rises to 120 => stop at 114 - rm.update_price("BTC/USDT", Decimal("120")) + rm.update_price("AAPL", Decimal("120")) # Price at 115 is above stop (114), should be allowed signal = make_signal(side=OrderSide.BUY, price="115", quantity="0.01") @@ -124,7 +124,7 @@ def test_trailing_stop_updates_highest_price(): def test_trailing_stop_not_triggered_above_stop(): """Trailing stop should not trigger when price is above stop level.""" rm = make_risk_manager(trailing_stop_pct="5") - rm.set_trailing_stop("BTC/USDT", Decimal("100")) + rm.set_trailing_stop("AAPL", Decimal("100")) # Price at 96 is above stop (95), should be allowed signal = make_signal(side=OrderSide.BUY, price="96", quantity="0.01") @@ -140,11 +140,11 @@ def test_max_open_positions_check(): rm = make_risk_manager(max_open_positions=2) positions = { - "BTC/USDT": make_position("BTC/USDT", "1", "100", "100"), - 
"ETH/USDT": make_position("ETH/USDT", "10", "50", "50"), + "AAPL": make_position("AAPL", "1", "100", "100"), + "MSFT": make_position("MSFT", "10", "50", "50"), } - signal = make_signal(side=OrderSide.BUY, price="10", quantity="1", symbol="SOL/USDT") + signal = make_signal(side=OrderSide.BUY, price="10", quantity="1", symbol="TSLA") result = rm.check(signal, balance=Decimal("10000"), positions=positions, daily_pnl=Decimal("0")) assert result.allowed is False assert result.reason == "Max open positions reached" @@ -158,14 +158,14 @@ def test_volatility_calculation(): rm = make_risk_manager(volatility_lookback=5) # No history yet - assert rm.get_volatility("BTC/USDT") is None + assert rm.get_volatility("AAPL") is None # Feed prices prices = [100, 102, 98, 105, 101] for p in prices: - rm.update_price("BTC/USDT", Decimal(str(p))) + rm.update_price("AAPL", Decimal(str(p))) - vol = rm.get_volatility("BTC/USDT") + vol = rm.get_volatility("AAPL") assert vol is not None assert vol > 0 @@ -177,9 +177,9 @@ def test_position_size_with_volatility_scaling(): # Feed volatile prices prices = [100, 120, 80, 130, 70] for p in prices: - rm.update_price("BTC/USDT", Decimal(str(p))) + rm.update_price("AAPL", Decimal(str(p))) - size = rm.calculate_position_size("BTC/USDT", Decimal("10000")) + size = rm.calculate_position_size("AAPL", Decimal("10000")) base = Decimal("10000") * Decimal("0.1") # High volatility should reduce size below base @@ -192,9 +192,9 @@ def test_position_size_without_scaling(): prices = [100, 120, 80, 130, 70] for p in prices: - rm.update_price("BTC/USDT", Decimal(str(p))) + rm.update_price("AAPL", Decimal(str(p))) - size = rm.calculate_position_size("BTC/USDT", Decimal("10000")) + size = rm.calculate_position_size("AAPL", Decimal("10000")) base = Decimal("10000") * Decimal("0.1") assert size == base @@ -211,8 +211,8 @@ def test_portfolio_exposure_check_passes(): max_portfolio_exposure=0.8, ) positions = { - "BTCUSDT": Position( - symbol="BTCUSDT", + "AAPL": 
Position( + symbol="AAPL", quantity=Decimal("0.01"), avg_entry_price=Decimal("50000"), current_price=Decimal("50000"), @@ -230,8 +230,8 @@ def test_portfolio_exposure_check_rejects(): max_portfolio_exposure=0.3, ) positions = { - "BTCUSDT": Position( - symbol="BTCUSDT", + "AAPL": Position( + symbol="AAPL", quantity=Decimal("1"), avg_entry_price=Decimal("50000"), current_price=Decimal("50000"), @@ -263,10 +263,10 @@ def test_var_calculation(): daily_loss_limit_pct=Decimal("10"), ) for i in range(30): - rm.update_price("BTCUSDT", Decimal(str(100 + (i % 5) - 2))) + rm.update_price("AAPL", Decimal(str(100 + (i % 5) - 2))) positions = { - "BTCUSDT": Position( - symbol="BTCUSDT", + "AAPL": Position( + symbol="AAPL", quantity=Decimal("1"), avg_entry_price=Decimal("100"), current_price=Decimal("100"), @@ -357,7 +357,7 @@ def test_drawdown_check_rejects_in_check(): rm.update_balance(Decimal("10000")) signal = Signal( strategy="test", - symbol="BTC/USDT", + symbol="AAPL", side=OrderSide.BUY, price=Decimal("50000"), quantity=Decimal("0.01"), diff --git a/services/portfolio-manager/Dockerfile b/services/portfolio-manager/Dockerfile index b1a7681..0fa3f35 100644 --- a/services/portfolio-manager/Dockerfile +++ b/services/portfolio-manager/Dockerfile @@ -1,8 +1,15 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/portfolio-manager/ services/portfolio-manager/ RUN pip install --no-cache-dir ./services/portfolio-manager + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin ENV PYTHONPATH=/app +USER appuser CMD ["python", "-m", "portfolio_manager.main"] diff --git a/services/portfolio-manager/src/portfolio_manager/main.py b/services/portfolio-manager/src/portfolio_manager/main.py index a6823ae..f885aa8 100644 --- 
a/services/portfolio-manager/src/portfolio_manager/main.py +++ b/services/portfolio-manager/src/portfolio_manager/main.py @@ -2,6 +2,10 @@ import asyncio +import sqlalchemy.exc + +from portfolio_manager.config import PortfolioConfig +from portfolio_manager.portfolio import PortfolioTracker from shared.broker import RedisBroker from shared.db import Database from shared.events import Event, OrderEvent @@ -9,9 +13,7 @@ from shared.healthcheck import HealthCheckServer from shared.logging import setup_logging from shared.metrics import ServiceMetrics from shared.notifier import TelegramNotifier - -from portfolio_manager.config import PortfolioConfig -from portfolio_manager.portfolio import PortfolioTracker +from shared.shutdown import GracefulShutdown ORDERS_STREAM = "orders" @@ -51,8 +53,12 @@ async def snapshot_loop( while True: try: await save_snapshot(db, tracker, notifier, log) + except (sqlalchemy.exc.OperationalError, ConnectionError, TimeoutError) as exc: + log.warning("snapshot_db_error", error=str(exc)) + except (ValueError, KeyError, TypeError) as exc: + log.warning("snapshot_data_error", error=str(exc)) except Exception as exc: - log.error("snapshot_failed", error=str(exc)) + log.error("snapshot_failed", error=str(exc), exc_info=True) await asyncio.sleep(interval_hours * 3600) @@ -61,10 +67,10 @@ async def run() -> None: log = setup_logging("portfolio-manager", config.log_level, config.log_format) metrics = ServiceMetrics("portfolio_manager") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id ) - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) tracker = PortfolioTracker() health = HealthCheckServer( @@ -76,13 +82,16 @@ async def run() -> None: await health.start() metrics.service_up.labels(service="portfolio-manager").set(1) - db = Database(config.database_url) + db = 
Database(config.database_url.get_secret_value()) await db.connect() snapshot_task = asyncio.create_task( snapshot_loop(db, tracker, notifier, config.snapshot_interval_hours, log) ) + shutdown = GracefulShutdown() + shutdown.install_handlers() + GROUP = "portfolio-manager" CONSUMER = "portfolio-1" log.info("service_started", stream=ORDERS_STREAM) @@ -108,12 +117,16 @@ async def run() -> None: service="portfolio-manager", event_type="order" ).inc() await broker.ack(ORDERS_STREAM, GROUP, msg_id) + except (ValueError, KeyError, TypeError) as exc: + log.warning("pending_parse_error", error=str(exc), msg_id=msg_id) + await broker.ack(ORDERS_STREAM, GROUP, msg_id) + metrics.errors_total.labels(service="portfolio-manager", error_type="validation").inc() except Exception as exc: - log.error("pending_process_failed", error=str(exc), msg_id=msg_id) + log.error("pending_process_failed", error=str(exc), msg_id=msg_id, exc_info=True) metrics.errors_total.labels(service="portfolio-manager", error_type="processing").inc() try: - while True: + while not shutdown.is_shutting_down: messages = await broker.read_group(ORDERS_STREAM, GROUP, CONSUMER, count=10, block=1000) for msg_id, msg in messages: try: @@ -134,13 +147,21 @@ async def run() -> None: service="portfolio-manager", event_type="order" ).inc() await broker.ack(ORDERS_STREAM, GROUP, msg_id) + except (ValueError, KeyError, TypeError) as exc: + log.warning("message_parse_error", error=str(exc), msg_id=msg_id) + await broker.ack(ORDERS_STREAM, GROUP, msg_id) + metrics.errors_total.labels( + service="portfolio-manager", error_type="validation" + ).inc() except Exception as exc: - log.exception("message_processing_failed", error=str(exc), msg_id=msg_id) + log.error( + "message_processing_failed", error=str(exc), msg_id=msg_id, exc_info=True + ) metrics.errors_total.labels( service="portfolio-manager", error_type="processing" ).inc() except Exception as exc: - log.error("fatal_error", error=str(exc)) + log.error("fatal_error", 
error=str(exc), exc_info=True) await notifier.send_error(str(exc), "portfolio-manager") raise finally: diff --git a/services/portfolio-manager/tests/test_portfolio.py b/services/portfolio-manager/tests/test_portfolio.py index 768e071..c8a6894 100644 --- a/services/portfolio-manager/tests/test_portfolio.py +++ b/services/portfolio-manager/tests/test_portfolio.py @@ -2,15 +2,16 @@ from decimal import Decimal -from shared.models import Order, OrderSide, OrderStatus, OrderType from portfolio_manager.portfolio import PortfolioTracker +from shared.models import Order, OrderSide, OrderStatus, OrderType + def make_order(side: OrderSide, price: str, quantity: str) -> Order: """Helper to create a filled Order.""" return Order( signal_id="test-signal", - symbol="BTC/USDT", + symbol="AAPL", side=side, type=OrderType.MARKET, price=Decimal(price), @@ -24,7 +25,7 @@ def test_portfolio_add_buy_order() -> None: order = make_order(OrderSide.BUY, "50000", "0.1") tracker.apply_order(order) - position = tracker.get_position("BTC/USDT") + position = tracker.get_position("AAPL") assert position is not None assert position.quantity == Decimal("0.1") assert position.avg_entry_price == Decimal("50000") @@ -35,7 +36,7 @@ def test_portfolio_add_multiple_buys() -> None: tracker.apply_order(make_order(OrderSide.BUY, "50000", "0.1")) tracker.apply_order(make_order(OrderSide.BUY, "52000", "0.1")) - position = tracker.get_position("BTC/USDT") + position = tracker.get_position("AAPL") assert position is not None assert position.quantity == Decimal("0.2") assert position.avg_entry_price == Decimal("51000") @@ -46,7 +47,7 @@ def test_portfolio_sell_reduces_position() -> None: tracker.apply_order(make_order(OrderSide.BUY, "50000", "0.2")) tracker.apply_order(make_order(OrderSide.SELL, "55000", "0.1")) - position = tracker.get_position("BTC/USDT") + position = tracker.get_position("AAPL") assert position is not None assert position.quantity == Decimal("0.1") assert position.avg_entry_price == 
Decimal("50000") @@ -54,7 +55,7 @@ def test_portfolio_sell_reduces_position() -> None: def test_portfolio_no_position_returns_none() -> None: tracker = PortfolioTracker() - position = tracker.get_position("ETH/USDT") + position = tracker.get_position("MSFT") assert position is None @@ -66,7 +67,7 @@ def test_realized_pnl_on_sell() -> None: tracker.apply_order( Order( signal_id="s1", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, type=OrderType.MARKET, price=Decimal("50000"), @@ -80,7 +81,7 @@ def test_realized_pnl_on_sell() -> None: tracker.apply_order( Order( signal_id="s2", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.SELL, type=OrderType.MARKET, price=Decimal("55000"), @@ -98,7 +99,7 @@ def test_realized_pnl_on_loss() -> None: tracker.apply_order( Order( signal_id="s1", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, type=OrderType.MARKET, price=Decimal("50000"), @@ -109,7 +110,7 @@ def test_realized_pnl_on_loss() -> None: tracker.apply_order( Order( signal_id="s2", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.SELL, type=OrderType.MARKET, price=Decimal("45000"), @@ -128,7 +129,7 @@ def test_realized_pnl_accumulates() -> None: tracker.apply_order( Order( signal_id="s1", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, type=OrderType.MARKET, price=Decimal("50000"), @@ -141,7 +142,7 @@ def test_realized_pnl_accumulates() -> None: tracker.apply_order( Order( signal_id="s2", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.SELL, type=OrderType.MARKET, price=Decimal("55000"), @@ -154,7 +155,7 @@ def test_realized_pnl_accumulates() -> None: tracker.apply_order( Order( signal_id="s3", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.SELL, type=OrderType.MARKET, price=Decimal("60000"), diff --git a/services/portfolio-manager/tests/test_snapshot.py b/services/portfolio-manager/tests/test_snapshot.py index a464599..f2026e2 100644 --- a/services/portfolio-manager/tests/test_snapshot.py +++ 
b/services/portfolio-manager/tests/test_snapshot.py @@ -1,9 +1,10 @@ """Tests for save_snapshot in portfolio-manager.""" -import pytest from decimal import Decimal from unittest.mock import AsyncMock, MagicMock +import pytest + from shared.models import Position @@ -13,7 +14,7 @@ class TestSaveSnapshot: from portfolio_manager.main import save_snapshot pos = Position( - symbol="BTCUSDT", + symbol="AAPL", quantity=Decimal("0.5"), avg_entry_price=Decimal("50000"), current_price=Decimal("52000"), diff --git a/services/strategy-engine/Dockerfile b/services/strategy-engine/Dockerfile index de635dc..f1484e9 100644 --- a/services/strategy-engine/Dockerfile +++ b/services/strategy-engine/Dockerfile @@ -1,9 +1,16 @@ -FROM python:3.12-slim +FROM python:3.12-slim AS builder WORKDIR /app COPY shared/ shared/ RUN pip install --no-cache-dir ./shared COPY services/strategy-engine/ services/strategy-engine/ RUN pip install --no-cache-dir ./services/strategy-engine + +FROM python:3.12-slim +RUN useradd -r -s /bin/false appuser +WORKDIR /app +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin COPY services/strategy-engine/strategies/ /app/strategies/ ENV PYTHONPATH=/app +USER appuser CMD ["python", "-m", "strategy_engine.main"] diff --git a/services/strategy-engine/pyproject.toml b/services/strategy-engine/pyproject.toml index 4f5b6be..e4bfb12 100644 --- a/services/strategy-engine/pyproject.toml +++ b/services/strategy-engine/pyproject.toml @@ -3,11 +3,7 @@ name = "strategy-engine" version = "0.1.0" description = "Plugin-based strategy execution engine" requires-python = ">=3.12" -dependencies = [ - "pandas>=2.0", - "numpy>=1.20", - "trading-shared", -] +dependencies = ["pandas>=2.1,<3", "numpy>=1.26,<3", "trading-shared"] [project.optional-dependencies] dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] diff --git a/services/strategy-engine/src/strategy_engine/config.py 
b/services/strategy-engine/src/strategy_engine/config.py index e3a49c2..9fd9c49 100644 --- a/services/strategy-engine/src/strategy_engine/config.py +++ b/services/strategy-engine/src/strategy_engine/config.py @@ -4,6 +4,6 @@ from shared.config import Settings class StrategyConfig(Settings): - symbols: list[str] = ["BTC/USDT"] + symbols: list[str] = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"] timeframes: list[str] = ["1m"] strategy_params: dict = {} diff --git a/services/strategy-engine/src/strategy_engine/engine.py b/services/strategy-engine/src/strategy_engine/engine.py index d401aee..4b2c468 100644 --- a/services/strategy-engine/src/strategy_engine/engine.py +++ b/services/strategy-engine/src/strategy_engine/engine.py @@ -2,11 +2,11 @@ import logging -from shared.broker import RedisBroker -from shared.events import CandleEvent, SignalEvent, Event - from strategies.base import BaseStrategy +from shared.broker import RedisBroker +from shared.events import CandleEvent, Event, SignalEvent + logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ class StrategyEngine: try: event = Event.from_dict(raw) except Exception as exc: - logger.warning("Failed to parse event: %s – %s", raw, exc) + logger.warning("Failed to parse event: %s - %s", raw, exc) continue if not isinstance(event, CandleEvent): diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py index 30de528..3d73058 100644 --- a/services/strategy-engine/src/strategy_engine/main.py +++ b/services/strategy-engine/src/strategy_engine/main.py @@ -1,17 +1,25 @@ """Strategy Engine Service entry point.""" import asyncio +import zoneinfo +from datetime import datetime from pathlib import Path +import aiohttp + +from shared.alpaca import AlpacaClient from shared.broker import RedisBroker +from shared.db import Database from shared.healthcheck import HealthCheckServer from shared.logging import setup_logging from shared.metrics import ServiceMetrics from 
shared.notifier import TelegramNotifier - +from shared.sentiment_models import MarketSentiment +from shared.shutdown import GracefulShutdown from strategy_engine.config import StrategyConfig from strategy_engine.engine import StrategyEngine from strategy_engine.plugin_loader import load_strategies +from strategy_engine.stock_selector import StockSelector # The strategies directory lives alongside the installed package STRATEGIES_DIR = Path(__file__).parent.parent.parent.parent / "strategies" @@ -30,23 +38,74 @@ async def process_symbol(engine: StrategyEngine, stream: str, log) -> None: last_id = await engine.process_once(stream, last_id) +async def run_stock_selector( + selector: StockSelector, + notifier: TelegramNotifier, + db: Database, + config: StrategyConfig, + log, +) -> None: + """Run the stock selector once per day at the configured time.""" + et = zoneinfo.ZoneInfo("America/New_York") + + while True: + now_et = datetime.now(et) + target_hour, target_min = map(int, config.selector_final_time.split(":")) + + if now_et.hour == target_hour and now_et.minute == target_min: + log.info("stock_selector_running") + try: + selections = await selector.select() + if selections: + ms_data = await db.get_latest_market_sentiment() + ms = None + if ms_data: + ms = MarketSentiment(**ms_data) + await notifier.send_stock_selection(selections, ms) + log.info("stock_selector_complete", picks=[s.symbol for s in selections]) + else: + log.info("stock_selector_no_picks") + except (aiohttp.ClientError, ConnectionError, TimeoutError) as exc: + log.warning("stock_selector_network_error", error=str(exc)) + except (ValueError, KeyError, TypeError) as exc: + log.warning("stock_selector_data_error", error=str(exc)) + except Exception as exc: + log.error("stock_selector_error", error=str(exc), exc_info=True) + await asyncio.sleep(120) # Sleep past this minute + else: + await asyncio.sleep(30) + + async def run() -> None: config = StrategyConfig() log = setup_logging("strategy-engine", 
config.log_level, config.log_format) metrics = ServiceMetrics("strategy_engine") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id, ) - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) + + db = Database(config.database_url.get_secret_value()) + await db.connect() + + alpaca = AlpacaClient( + api_key=config.alpaca_api_key.get_secret_value(), + api_secret=config.alpaca_api_secret.get_secret_value(), + paper=config.alpaca_paper, + ) + strategies = load_strategies(STRATEGIES_DIR) for strategy in strategies: params = config.strategy_params.get(strategy.name, {}) strategy.configure(params) + shutdown = GracefulShutdown() + shutdown.install_handlers() + log.info("loaded_strategies", count=len(strategies), names=[s.name for s in strategies]) engine = StrategyEngine(broker=broker, strategies=strategies) @@ -67,9 +126,23 @@ async def run() -> None: task = asyncio.create_task(process_symbol(engine, stream, log)) tasks.append(task) - await asyncio.gather(*tasks) + if config.anthropic_api_key.get_secret_value(): + selector = StockSelector( + db=db, + broker=broker, + alpaca=alpaca, + anthropic_api_key=config.anthropic_api_key.get_secret_value(), + anthropic_model=config.anthropic_model, + max_picks=config.selector_max_picks, + ) + tasks.append( + asyncio.create_task(run_stock_selector(selector, notifier, db, config, log)) + ) + log.info("stock_selector_enabled", time=config.selector_final_time) + + await shutdown.wait() except Exception as exc: - log.error("fatal_error", error=str(exc)) + log.error("fatal_error", error=str(exc), exc_info=True) await notifier.send_error(str(exc), "strategy-engine") raise finally: @@ -78,6 +151,8 @@ async def run() -> None: metrics.service_up.labels(service="strategy-engine").set(0) await notifier.close() await broker.close() + await alpaca.close() + await db.close() def main() -> None: diff 
--git a/services/strategy-engine/src/strategy_engine/plugin_loader.py b/services/strategy-engine/src/strategy_engine/plugin_loader.py index 62e4160..57680db 100644 --- a/services/strategy-engine/src/strategy_engine/plugin_loader.py +++ b/services/strategy-engine/src/strategy_engine/plugin_loader.py @@ -5,7 +5,6 @@ import sys from pathlib import Path import yaml - from strategies.base import BaseStrategy diff --git a/services/strategy-engine/src/strategy_engine/stock_selector.py b/services/strategy-engine/src/strategy_engine/stock_selector.py new file mode 100644 index 0000000..8657b93 --- /dev/null +++ b/services/strategy-engine/src/strategy_engine/stock_selector.py @@ -0,0 +1,418 @@ +"""3-stage stock selector engine: sentiment → technical → LLM.""" + +import asyncio +import json +import logging +import re +from datetime import UTC, datetime + +import aiohttp + +from shared.alpaca import AlpacaClient +from shared.broker import RedisBroker +from shared.db import Database +from shared.models import OrderSide +from shared.sentiment_models import Candidate, MarketSentiment, SelectedStock + +logger = logging.getLogger(__name__) + +ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages" + + +def _extract_json_array(text: str) -> list[dict] | None: + """Extract a JSON array from text that may contain markdown code blocks.""" + code_block = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) + if code_block: + raw = code_block.group(1) + else: + array_match = re.search(r"\[.*\]", text, re.DOTALL) + if array_match: + raw = array_match.group(0) + else: + raw = text.strip() + + try: + data = json.loads(raw) + if isinstance(data, list): + return [item for item in data if isinstance(item, dict)] + return None + except (json.JSONDecodeError, TypeError): + return None + + +def _parse_llm_selections(text: str) -> list[SelectedStock]: + """Parse LLM response into SelectedStock list. + + Handles both bare JSON arrays and markdown code blocks. 
+ Returns empty list on any parse error. + """ + items = _extract_json_array(text) + if items is None: + return [] + + selections = [] + for item in items: + try: + selection = SelectedStock( + symbol=item["symbol"], + side=OrderSide(item["side"]), + conviction=float(item["conviction"]), + reason=item.get("reason", ""), + key_news=item.get("key_news", []), + ) + selections.append(selection) + except (KeyError, ValueError) as e: + logger.warning("Skipping invalid selection item: %s", e) + return selections + + +class SentimentCandidateSource: + """Generates candidates from DB sentiment scores.""" + + def __init__(self, db: Database) -> None: + self._db = db + + async def get_candidates(self) -> list[Candidate]: + rows = await self._db.get_top_symbol_scores(limit=20) + candidates = [] + for row in rows: + composite = float(row.get("composite", 0)) + if composite == 0: + continue + candidates.append( + Candidate( + symbol=row["symbol"], + source="sentiment", + score=composite, + reason=f"composite={composite:.2f}, news_count={row.get('news_count', 0)}", + ) + ) + return candidates + + +class LLMCandidateSource: + """Generates candidates by asking Claude to analyze recent news.""" + + def __init__(self, db: Database, api_key: str, model: str) -> None: + self._db = db + self._api_key = api_key + self._model = model + + async def get_candidates(self, session: aiohttp.ClientSession | None = None) -> list[Candidate]: + news_items = await self._db.get_recent_news(hours=24) + if not news_items: + return [] + + headlines = [] + for item in news_items[:50]: # cap at 50 to stay within context + symbols = item.get("symbols", []) + sym_str = ", ".join(symbols) if symbols else "N/A" + headlines.append(f"[{sym_str}] {item['headline']}") + + prompt = ( + "You are a stock analyst. Given recent news headlines, identify the 5-10 most " + "actionable US stock tickers. 
Return ONLY a JSON array with objects having: " + "symbol (ticker), direction ('BUY' or 'SELL'), score (0-1), reason (brief).\n\n" + "Headlines:\n" + "\n".join(headlines) + ) + + own_session = session is None + if own_session: + session = aiohttp.ClientSession() + + try: + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM candidate source error %d: %s", resp.status, body) + return [] + data = await resp.json() + + content = data.get("content", []) + text = "" + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text += block.get("text", "") + + return self._parse_candidates(text) + except Exception as e: + logger.error("LLMCandidateSource error: %s", e) + return [] + finally: + if own_session: + await session.close() + + def _parse_candidates(self, text: str) -> list[Candidate]: + items = _extract_json_array(text) + if items is None: + return [] + + candidates = [] + for item in items: + try: + direction_str = item.get("direction", "BUY") + direction = OrderSide(direction_str) + except ValueError: + direction = None + candidates.append( + Candidate( + symbol=item["symbol"], + source="llm", + direction=direction, + score=float(item.get("score", 0.5)), + reason=item.get("reason", ""), + ) + ) + return candidates + + +def _compute_rsi(closes: list[float], period: int = 14) -> float: + """Compute RSI for the last data point.""" + if len(closes) < period + 1: + return 50.0 # neutral if insufficient data + + deltas = [closes[i] - closes[i - 1] for i in range(1, len(closes))] + gains = [d if d > 0 else 0.0 for d in deltas] + losses = [-d if d < 0 else 0.0 for d in deltas] + + avg_gain = sum(gains[:period]) / period + avg_loss = 
sum(losses[:period]) / period + + for i in range(period, len(deltas)): + avg_gain = (avg_gain * (period - 1) + gains[i]) / period + avg_loss = (avg_loss * (period - 1) + losses[i]) / period + + if avg_loss == 0: + return 100.0 + rs = avg_gain / avg_loss + return 100.0 - (100.0 / (1.0 + rs)) + + +class StockSelector: + """Orchestrates the 3-stage stock selection pipeline.""" + + def __init__( + self, + db: Database, + broker: RedisBroker, + alpaca: AlpacaClient, + anthropic_api_key: str, + anthropic_model: str = "claude-sonnet-4-20250514", + max_picks: int = 3, + ) -> None: + self._db = db + self._broker = broker + self._alpaca = alpaca + self._api_key = anthropic_api_key + self._model = anthropic_model + self._max_picks = max_picks + self._http_session: aiohttp.ClientSession | None = None + self._session_lock = asyncio.Lock() + + async def _ensure_session(self) -> aiohttp.ClientSession: + async with self._session_lock: + if self._http_session is None or self._http_session.closed: + self._http_session = aiohttp.ClientSession() + return self._http_session + + async def close(self) -> None: + if self._http_session and not self._http_session.closed: + await self._http_session.close() + + async def select(self) -> list[SelectedStock]: + """Run the full 3-stage pipeline and return selected stocks.""" + # Market gate: check sentiment + sentiment_data = await self._db.get_latest_market_sentiment() + if sentiment_data is None: + logger.warning("No market sentiment data; skipping selection") + return [] + + market_sentiment = MarketSentiment(**sentiment_data) + if market_sentiment.market_regime == "risk_off": + logger.info("Market is risk_off; skipping stock selection") + return [] + + # Stage 1: gather candidates from both sources + sentiment_source = SentimentCandidateSource(self._db) + llm_source = LLMCandidateSource(self._db, self._api_key, self._model) + + session = await self._ensure_session() + sentiment_candidates = await sentiment_source.get_candidates() + 
llm_candidates = await llm_source.get_candidates(session=session) + + candidates = self._merge_candidates(sentiment_candidates, llm_candidates) + if not candidates: + logger.info("No candidates found") + return [] + + # Stage 2: technical filter + filtered = await self._technical_filter(candidates) + if not filtered: + logger.info("All candidates filtered out by technical criteria") + return [] + + # Stage 3: LLM final selection + selections = await self._llm_final_select(filtered, market_sentiment) + + # Persist and publish + today = datetime.now(UTC).date() + sentiment_snapshot = { + "fear_greed": market_sentiment.fear_greed, + "market_regime": market_sentiment.market_regime, + "vix": market_sentiment.vix, + } + for stock in selections: + try: + await self._db.insert_stock_selection( + trade_date=today, + symbol=stock.symbol, + side=stock.side.value, + conviction=stock.conviction, + reason=stock.reason, + key_news=stock.key_news, + sentiment_snapshot=sentiment_snapshot, + ) + except Exception as e: + logger.error("Failed to persist selection for %s: %s", stock.symbol, e) + + try: + await self._broker.publish( + "selected_stocks", + { + "symbol": stock.symbol, + "side": stock.side.value, + "conviction": stock.conviction, + "reason": stock.reason, + "key_news": stock.key_news, + "trade_date": str(today), + }, + ) + except Exception as e: + logger.error("Failed to publish selection for %s: %s", stock.symbol, e) + + return selections + + def _merge_candidates( + self, sentiment: list[Candidate], llm: list[Candidate] + ) -> list[Candidate]: + """Deduplicate candidates by symbol, keeping the higher score.""" + by_symbol: dict[str, Candidate] = {} + for c in sentiment + llm: + existing = by_symbol.get(c.symbol) + if existing is None or c.score > existing.score: + by_symbol[c.symbol] = c + return sorted(by_symbol.values(), key=lambda c: c.score, reverse=True) + + async def _technical_filter(self, candidates: list[Candidate]) -> list[Candidate]: + """Filter candidates 
using RSI, EMA20, and volume criteria.""" + passed = [] + for candidate in candidates: + try: + bars = await self._alpaca.get_bars(candidate.symbol, timeframe="1Day", limit=60) + if len(bars) < 21: + logger.debug("Insufficient bars for %s", candidate.symbol) + continue + + closes = [float(b["c"]) for b in bars] + volumes = [float(b["v"]) for b in bars] + + rsi = _compute_rsi(closes) + if not (30 <= rsi <= 70): + logger.debug("%s RSI=%.1f outside 30-70", candidate.symbol, rsi) + continue + + ema20 = sum(closes[-20:]) / 20 # simple approximation + current_price = closes[-1] + if current_price <= ema20: + logger.debug( + "%s price %.2f <= EMA20 %.2f", candidate.symbol, current_price, ema20 + ) + continue + + avg_volume = sum(volumes[:-1]) / max(len(volumes) - 1, 1) + current_volume = volumes[-1] + if current_volume <= 0.5 * avg_volume: + logger.debug( + "%s volume %.0f <= 50%% avg %.0f", + candidate.symbol, + current_volume, + avg_volume, + ) + continue + + passed.append(candidate) + except Exception as e: + logger.warning("Technical filter error for %s: %s", candidate.symbol, e) + + return passed + + async def _llm_final_select( + self, candidates: list[Candidate], market_sentiment: MarketSentiment + ) -> list[SelectedStock]: + """Ask Claude to pick 2-3 stocks with rationale.""" + candidate_lines = [ + f"- {c.symbol} (source={c.source}, score={c.score:.2f}, reason={c.reason})" + for c in candidates + ] + market_context = ( + f"Fear/Greed: {market_sentiment.fear_greed} ({market_sentiment.fear_greed_label}), " + f"VIX: {market_sentiment.vix}, " + f"Fed stance: {market_sentiment.fed_stance}, " + f"Regime: {market_sentiment.market_regime}" + ) + + prompt = ( + f"You are a portfolio manager. 
Select 2-3 stocks for today's session.\n\n" + f"Market context: {market_context}\n\n" + f"Candidates (already passed technical filters):\n" + + "\n".join(candidate_lines) + + "\n\n" + "Return ONLY a JSON array with objects having:\n" + " symbol, side ('BUY' or 'SELL'), conviction (0-1), reason (1-2 sentences), " + "key_news (list of 1-3 relevant headlines or facts)\n" + f"Select at most {self._max_picks} stocks." + ) + + try: + session = await self._ensure_session() + async with session.post( + ANTHROPIC_API_URL, + headers={ + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + json={ + "model": self._model, + "max_tokens": 1024, + "messages": [{"role": "user", "content": prompt}], + }, + ) as resp: + if resp.status != 200: + body = await resp.text() + logger.error("LLM final select error %d: %s", resp.status, body) + return [] + data = await resp.json() + + content = data.get("content", []) + text = "" + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text += block.get("text", "") + + return _parse_llm_selections(text)[: self._max_picks] + except Exception as e: + logger.error("LLM final select error: %s", e) + return [] diff --git a/services/strategy-engine/strategies/base.py b/services/strategy-engine/strategies/base.py index d5be675..1d9d289 100644 --- a/services/strategy-engine/strategies/base.py +++ b/services/strategy-engine/strategies/base.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from collections import deque from decimal import Decimal -from typing import Optional import pandas as pd @@ -102,7 +101,7 @@ class BaseStrategy(ABC): def _calculate_atr_stops( self, entry_price: Decimal, side: str - ) -> tuple[Optional[Decimal], Optional[Decimal]]: + ) -> tuple[Decimal | None, Decimal | None]: """Calculate ATR-based stop-loss and take-profit. Returns (stop_loss, take_profit) as Decimal or (None, None) if not enough data. 
@@ -131,7 +130,7 @@ class BaseStrategy(ABC): return sl, tp - def _apply_filters(self, signal: Signal) -> Optional[Signal]: + def _apply_filters(self, signal: Signal) -> Signal | None: """Apply all filters to a signal. Returns signal with SL/TP or None if filtered out.""" if signal is None: return None diff --git a/services/strategy-engine/strategies/bollinger_strategy.py b/services/strategy-engine/strategies/bollinger_strategy.py index ebe7967..02ff09a 100644 --- a/services/strategy-engine/strategies/bollinger_strategy.py +++ b/services/strategy-engine/strategies/bollinger_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/combined_strategy.py b/services/strategy-engine/strategies/combined_strategy.py index ba92485..f562918 100644 --- a/services/strategy-engine/strategies/combined_strategy.py +++ b/services/strategy-engine/strategies/combined_strategy.py @@ -2,7 +2,7 @@ from decimal import Decimal -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/ema_crossover_strategy.py b/services/strategy-engine/strategies/ema_crossover_strategy.py index 68d0ba3..9c181f3 100644 --- a/services/strategy-engine/strategies/ema_crossover_strategy.py +++ b/services/strategy-engine/strategies/ema_crossover_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/grid_strategy.py b/services/strategy-engine/strategies/grid_strategy.py index 283bfe5..491252e 100644 --- 
a/services/strategy-engine/strategies/grid_strategy.py +++ b/services/strategy-engine/strategies/grid_strategy.py @@ -1,9 +1,8 @@ from decimal import Decimal -from typing import Optional import numpy as np -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy @@ -17,7 +16,7 @@ class GridStrategy(BaseStrategy): self._grid_count: int = 5 self._quantity: Decimal = Decimal("0.01") self._grid_levels: list[float] = [] - self._last_zone: Optional[int] = None + self._last_zone: int | None = None self._exit_threshold_pct: float = 5.0 self._out_of_range: bool = False self._in_position: bool = False # Track if we have any grid positions diff --git a/services/strategy-engine/strategies/indicators/__init__.py b/services/strategy-engine/strategies/indicators/__init__.py index 3c713e6..01637b7 100644 --- a/services/strategy-engine/strategies/indicators/__init__.py +++ b/services/strategy-engine/strategies/indicators/__init__.py @@ -1,21 +1,21 @@ """Reusable technical indicator functions.""" -from strategies.indicators.trend import ema, sma, macd, adx -from strategies.indicators.volatility import atr, bollinger_bands, keltner_channels from strategies.indicators.momentum import rsi, stochastic -from strategies.indicators.volume import volume_sma, volume_ratio, obv +from strategies.indicators.trend import adx, ema, macd, sma +from strategies.indicators.volatility import atr, bollinger_bands, keltner_channels +from strategies.indicators.volume import obv, volume_ratio, volume_sma __all__ = [ - "ema", - "sma", - "macd", "adx", "atr", "bollinger_bands", + "ema", "keltner_channels", + "macd", + "obv", "rsi", + "sma", "stochastic", - "volume_sma", "volume_ratio", - "obv", + "volume_sma", ] diff --git a/services/strategy-engine/strategies/indicators/momentum.py b/services/strategy-engine/strategies/indicators/momentum.py index c479452..a82210b 100644 --- 
a/services/strategy-engine/strategies/indicators/momentum.py +++ b/services/strategy-engine/strategies/indicators/momentum.py @@ -1,7 +1,7 @@ """Momentum indicators: RSI, Stochastic.""" -import pandas as pd import numpy as np +import pandas as pd def rsi(closes: pd.Series, period: int = 14) -> pd.Series: diff --git a/services/strategy-engine/strategies/indicators/trend.py b/services/strategy-engine/strategies/indicators/trend.py index c94a071..1085199 100644 --- a/services/strategy-engine/strategies/indicators/trend.py +++ b/services/strategy-engine/strategies/indicators/trend.py @@ -1,7 +1,7 @@ """Trend indicators: EMA, SMA, MACD, ADX.""" -import pandas as pd import numpy as np +import pandas as pd def sma(series: pd.Series, period: int) -> pd.Series: diff --git a/services/strategy-engine/strategies/indicators/volatility.py b/services/strategy-engine/strategies/indicators/volatility.py index c16143e..da82f26 100644 --- a/services/strategy-engine/strategies/indicators/volatility.py +++ b/services/strategy-engine/strategies/indicators/volatility.py @@ -1,7 +1,7 @@ """Volatility indicators: ATR, Bollinger Bands, Keltner Channels.""" -import pandas as pd import numpy as np +import pandas as pd def atr( diff --git a/services/strategy-engine/strategies/indicators/volume.py b/services/strategy-engine/strategies/indicators/volume.py index 502f1ce..d7c6471 100644 --- a/services/strategy-engine/strategies/indicators/volume.py +++ b/services/strategy-engine/strategies/indicators/volume.py @@ -1,7 +1,7 @@ """Volume indicators: Volume SMA, Volume Ratio, OBV.""" -import pandas as pd import numpy as np +import pandas as pd def volume_sma(volumes: pd.Series, period: int = 20) -> pd.Series: diff --git a/services/strategy-engine/strategies/macd_strategy.py b/services/strategy-engine/strategies/macd_strategy.py index 356a42b..b5aea07 100644 --- a/services/strategy-engine/strategies/macd_strategy.py +++ b/services/strategy-engine/strategies/macd_strategy.py @@ -3,7 +3,7 @@ from 
decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/moc_strategy.py b/services/strategy-engine/strategies/moc_strategy.py index 7eaa59e..cbc8440 100644 --- a/services/strategy-engine/strategies/moc_strategy.py +++ b/services/strategy-engine/strategies/moc_strategy.py @@ -8,12 +8,12 @@ Rules: """ from collections import deque -from decimal import Decimal from datetime import datetime +from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/rsi_strategy.py b/services/strategy-engine/strategies/rsi_strategy.py index 0646d8c..2df080d 100644 --- a/services/strategy-engine/strategies/rsi_strategy.py +++ b/services/strategy-engine/strategies/rsi_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import pandas as pd -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy diff --git a/services/strategy-engine/strategies/volume_profile_strategy.py b/services/strategy-engine/strategies/volume_profile_strategy.py index ef2ae14..67b5c23 100644 --- a/services/strategy-engine/strategies/volume_profile_strategy.py +++ b/services/strategy-engine/strategies/volume_profile_strategy.py @@ -3,7 +3,7 @@ from decimal import Decimal import numpy as np -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy @@ -137,7 +137,7 @@ class VolumeProfileStrategy(BaseStrategy): if result is None: return None - poc, va_low, va_high, hvn_levels, lvn_levels = result + poc, va_low, va_high, hvn_levels, _lvn_levels = result if close < va_low: 
self._was_below_va = True diff --git a/services/strategy-engine/strategies/vwap_strategy.py b/services/strategy-engine/strategies/vwap_strategy.py index d64950e..4ee4952 100644 --- a/services/strategy-engine/strategies/vwap_strategy.py +++ b/services/strategy-engine/strategies/vwap_strategy.py @@ -1,7 +1,7 @@ from collections import deque from decimal import Decimal -from shared.models import Candle, Signal, OrderSide +from shared.models import Candle, OrderSide, Signal from strategies.base import BaseStrategy @@ -107,7 +107,7 @@ class VwapStrategy(BaseStrategy): # Standard deviation of (TP - VWAP) for bands std_dev = 0.0 if len(self._tp_values) >= 2: - diffs = [tp - v for tp, v in zip(self._tp_values, self._vwap_values)] + diffs = [tp - v for tp, v in zip(self._tp_values, self._vwap_values, strict=True)] mean_diff = sum(diffs) / len(diffs) variance = sum((d - mean_diff) ** 2 for d in diffs) / len(diffs) std_dev = variance**0.5 diff --git a/services/strategy-engine/tests/conftest.py b/services/strategy-engine/tests/conftest.py index eb31b23..2b909ef 100644 --- a/services/strategy-engine/tests/conftest.py +++ b/services/strategy-engine/tests/conftest.py @@ -7,3 +7,8 @@ from pathlib import Path STRATEGIES_DIR = Path(__file__).parent.parent / "strategies" if str(STRATEGIES_DIR) not in sys.path: sys.path.insert(0, str(STRATEGIES_DIR.parent)) + +# Ensure the worktree's strategy_engine src is preferred over any installed version +WORKTREE_SRC = Path(__file__).parent.parent / "src" +if str(WORKTREE_SRC) not in sys.path: + sys.path.insert(0, str(WORKTREE_SRC)) diff --git a/services/strategy-engine/tests/test_base_filters.py b/services/strategy-engine/tests/test_base_filters.py index ae9ca05..66adec7 100644 --- a/services/strategy-engine/tests/test_base_filters.py +++ b/services/strategy-engine/tests/test_base_filters.py @@ -5,12 +5,13 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +from datetime import UTC, datetime from decimal 
import Decimal -from datetime import datetime, timezone -from shared.models import Candle, Signal, OrderSide from strategies.base import BaseStrategy +from shared.models import Candle, OrderSide, Signal + class DummyStrategy(BaseStrategy): name = "dummy" @@ -45,7 +46,7 @@ def _candle(price=100.0, volume=10.0, high=None, low=None): return Candle( symbol="AAPL", timeframe="1h", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, tzinfo=UTC), open=Decimal(str(price)), high=Decimal(str(h)), low=Decimal(str(lo)), diff --git a/services/strategy-engine/tests/test_bollinger_strategy.py b/services/strategy-engine/tests/test_bollinger_strategy.py index 8261377..70ec66e 100644 --- a/services/strategy-engine/tests/test_bollinger_strategy.py +++ b/services/strategy-engine/tests/test_bollinger_strategy.py @@ -1,18 +1,18 @@ """Tests for the Bollinger Bands strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.bollinger_strategy import BollingerStrategy from shared.models import Candle, OrderSide -from strategies.bollinger_strategy import BollingerStrategy def make_candle(close: float) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), diff --git a/services/strategy-engine/tests/test_combined_strategy.py b/services/strategy-engine/tests/test_combined_strategy.py index 8a4dc74..6a15250 100644 --- a/services/strategy-engine/tests/test_combined_strategy.py +++ b/services/strategy-engine/tests/test_combined_strategy.py @@ -5,13 +5,14 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +from datetime import UTC, datetime from decimal import Decimal -from datetime import datetime, timezone -import pytest -from shared.models import Candle, Signal, 
OrderSide -from strategies.combined_strategy import CombinedStrategy +import pytest from strategies.base import BaseStrategy +from strategies.combined_strategy import CombinedStrategy + +from shared.models import Candle, OrderSide, Signal class AlwaysBuyStrategy(BaseStrategy): @@ -74,7 +75,7 @@ def _candle(price=100.0): return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, tzinfo=UTC), open=Decimal(str(price)), high=Decimal(str(price + 10)), low=Decimal(str(price - 10)), diff --git a/services/strategy-engine/tests/test_ema_crossover_strategy.py b/services/strategy-engine/tests/test_ema_crossover_strategy.py index 7028eb0..af2b587 100644 --- a/services/strategy-engine/tests/test_ema_crossover_strategy.py +++ b/services/strategy-engine/tests/test_ema_crossover_strategy.py @@ -1,18 +1,18 @@ """Tests for the EMA Crossover strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.ema_crossover_strategy import EmaCrossoverStrategy from shared.models import Candle, OrderSide -from strategies.ema_crossover_strategy import EmaCrossoverStrategy def make_candle(close: float) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), diff --git a/services/strategy-engine/tests/test_engine.py b/services/strategy-engine/tests/test_engine.py index 2623027..fa888b5 100644 --- a/services/strategy-engine/tests/test_engine.py +++ b/services/strategy-engine/tests/test_engine.py @@ -1,21 +1,21 @@ """Tests for the StrategyEngine.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal from unittest.mock import AsyncMock, MagicMock import pytest +from strategy_engine.engine import StrategyEngine -from 
shared.models import Candle, Signal, OrderSide from shared.events import CandleEvent -from strategy_engine.engine import StrategyEngine +from shared.models import Candle, OrderSide, Signal def make_candle_event() -> dict: candle = Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("50100"), low=Decimal("49900"), diff --git a/services/strategy-engine/tests/test_grid_strategy.py b/services/strategy-engine/tests/test_grid_strategy.py index 878b900..f697012 100644 --- a/services/strategy-engine/tests/test_grid_strategy.py +++ b/services/strategy-engine/tests/test_grid_strategy.py @@ -1,18 +1,18 @@ """Tests for the Grid strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.grid_strategy import GridStrategy from shared.models import Candle, OrderSide -from strategies.grid_strategy import GridStrategy def make_candle(close: float) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), diff --git a/services/strategy-engine/tests/test_indicators.py b/services/strategy-engine/tests/test_indicators.py index 481569b..3147fc4 100644 --- a/services/strategy-engine/tests/test_indicators.py +++ b/services/strategy-engine/tests/test_indicators.py @@ -5,14 +5,13 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) -import pandas as pd import numpy as np +import pandas as pd import pytest - -from strategies.indicators.trend import sma, ema, macd, adx -from strategies.indicators.volatility import atr, bollinger_bands from strategies.indicators.momentum import rsi, stochastic -from strategies.indicators.volume import volume_sma, volume_ratio, obv +from 
strategies.indicators.trend import adx, ema, macd, sma +from strategies.indicators.volatility import atr, bollinger_bands +from strategies.indicators.volume import obv, volume_ratio, volume_sma class TestTrend: diff --git a/services/strategy-engine/tests/test_macd_strategy.py b/services/strategy-engine/tests/test_macd_strategy.py index 556fd4c..7fac16f 100644 --- a/services/strategy-engine/tests/test_macd_strategy.py +++ b/services/strategy-engine/tests/test_macd_strategy.py @@ -1,18 +1,18 @@ """Tests for the MACD strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.macd_strategy import MacdStrategy from shared.models import Candle, OrderSide -from strategies.macd_strategy import MacdStrategy def _candle(price: float) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(price)), high=Decimal(str(price)), low=Decimal(str(price)), diff --git a/services/strategy-engine/tests/test_moc_strategy.py b/services/strategy-engine/tests/test_moc_strategy.py index 1928a28..076e846 100644 --- a/services/strategy-engine/tests/test_moc_strategy.py +++ b/services/strategy-engine/tests/test_moc_strategy.py @@ -5,19 +5,20 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal -from shared.models import Candle, OrderSide from strategies.moc_strategy import MocStrategy +from shared.models import Candle, OrderSide + def _candle(price, hour=20, minute=0, volume=100.0, day=1, open_price=None): op = open_price if open_price is not None else price - 1 # Default: bullish return Candle( symbol="AAPL", timeframe="5Min", - open_time=datetime(2025, 1, day, hour, minute, tzinfo=timezone.utc), + open_time=datetime(2025, 1, day, hour, minute, tzinfo=UTC), 
open=Decimal(str(op)), high=Decimal(str(price + 1)), low=Decimal(str(min(op, price) - 1)), diff --git a/services/strategy-engine/tests/test_multi_symbol.py b/services/strategy-engine/tests/test_multi_symbol.py index 671a9d3..922bfc2 100644 --- a/services/strategy-engine/tests/test_multi_symbol.py +++ b/services/strategy-engine/tests/test_multi_symbol.py @@ -9,11 +9,13 @@ import pytest sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +from datetime import UTC, datetime +from decimal import Decimal + from strategy_engine.engine import StrategyEngine + from shared.events import CandleEvent from shared.models import Candle -from decimal import Decimal -from datetime import datetime, timezone @pytest.mark.asyncio @@ -24,7 +26,7 @@ async def test_engine_processes_multiple_streams(): candle_btc = Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, tzinfo=UTC), open=Decimal("50000"), high=Decimal("51000"), low=Decimal("49000"), @@ -34,7 +36,7 @@ async def test_engine_processes_multiple_streams(): candle_eth = Candle( symbol="MSFT", timeframe="1m", - open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2025, 1, 1, tzinfo=UTC), open=Decimal("3000"), high=Decimal("3100"), low=Decimal("2900"), diff --git a/services/strategy-engine/tests/test_plugin_loader.py b/services/strategy-engine/tests/test_plugin_loader.py index 5191fc3..7bd450f 100644 --- a/services/strategy-engine/tests/test_plugin_loader.py +++ b/services/strategy-engine/tests/test_plugin_loader.py @@ -2,10 +2,8 @@ from pathlib import Path - from strategy_engine.plugin_loader import load_strategies - STRATEGIES_DIR = Path(__file__).parent.parent / "strategies" diff --git a/services/strategy-engine/tests/test_rsi_strategy.py b/services/strategy-engine/tests/test_rsi_strategy.py index 6d31fd5..6c74f0b 100644 --- 
a/services/strategy-engine/tests/test_rsi_strategy.py +++ b/services/strategy-engine/tests/test_rsi_strategy.py @@ -1,18 +1,18 @@ """Tests for the RSI strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.rsi_strategy import RsiStrategy from shared.models import Candle, OrderSide -from strategies.rsi_strategy import RsiStrategy def make_candle(close: float, idx: int = 0) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), diff --git a/services/strategy-engine/tests/test_stock_selector.py b/services/strategy-engine/tests/test_stock_selector.py new file mode 100644 index 0000000..76b8541 --- /dev/null +++ b/services/strategy-engine/tests/test_stock_selector.py @@ -0,0 +1,111 @@ +"""Tests for stock selector engine.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock + +from strategy_engine.stock_selector import ( + SentimentCandidateSource, + StockSelector, + _extract_json_array, + _parse_llm_selections, +) + + +async def test_sentiment_candidate_source(): + mock_db = MagicMock() + mock_db.get_top_symbol_scores = AsyncMock( + return_value=[ + {"symbol": "AAPL", "composite": 0.8, "news_count": 5}, + {"symbol": "NVDA", "composite": 0.6, "news_count": 3}, + ] + ) + + source = SentimentCandidateSource(mock_db) + candidates = await source.get_candidates() + + assert len(candidates) == 2 + assert candidates[0].symbol == "AAPL" + assert candidates[0].source == "sentiment" + + +def test_parse_llm_selections_valid(): + llm_response = """ + [ + {"symbol": "NVDA", "side": "BUY", "conviction": 0.85, "reason": "AI demand", "key_news": ["NVDA beats earnings"]}, + {"symbol": "XOM", "side": "BUY", "conviction": 0.72, "reason": "Oil surge", "key_news": ["Oil prices up"]} + ] + """ + 
selections = _parse_llm_selections(llm_response) + assert len(selections) == 2 + assert selections[0].symbol == "NVDA" + assert selections[0].conviction == 0.85 + + +def test_parse_llm_selections_invalid(): + selections = _parse_llm_selections("not json") + assert selections == [] + + +def test_parse_llm_selections_with_markdown(): + llm_response = """ + Here are my picks: + ```json + [ + {"symbol": "TSLA", "side": "BUY", "conviction": 0.7, "reason": "Momentum", "key_news": ["Tesla rally"]} + ] + ``` + """ + selections = _parse_llm_selections(llm_response) + assert len(selections) == 1 + assert selections[0].symbol == "TSLA" + + +def test_extract_json_array_from_markdown(): + text = '```json\n[{"symbol": "AAPL", "score": 0.9}]\n```' + result = _extract_json_array(text) + assert result == [{"symbol": "AAPL", "score": 0.9}] + + +def test_extract_json_array_bare(): + text = '[{"symbol": "TSLA"}]' + result = _extract_json_array(text) + assert result == [{"symbol": "TSLA"}] + + +def test_extract_json_array_invalid(): + assert _extract_json_array("not json") is None + + +def test_extract_json_array_filters_non_dicts(): + text = '[{"symbol": "AAPL"}, "bad", 42]' + result = _extract_json_array(text) + assert result == [{"symbol": "AAPL"}] + + +async def test_selector_close(): + selector = StockSelector( + db=MagicMock(), broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test" + ) + # No session yet - close should be safe + await selector.close() + assert selector._http_session is None + + +async def test_selector_blocks_on_risk_off(): + mock_db = MagicMock() + mock_db.get_latest_market_sentiment = AsyncMock( + return_value={ + "fear_greed": 15, + "fear_greed_label": "Extreme Fear", + "vix": 35.0, + "fed_stance": "neutral", + "market_regime": "risk_off", + "updated_at": datetime.now(UTC), + } + ) + + selector = StockSelector( + db=mock_db, broker=MagicMock(), alpaca=MagicMock(), anthropic_api_key="test" + ) + result = await selector.select() + assert result == [] 
diff --git a/services/strategy-engine/tests/test_strategy_validation.py b/services/strategy-engine/tests/test_strategy_validation.py index debab1f..0d9607a 100644 --- a/services/strategy-engine/tests/test_strategy_validation.py +++ b/services/strategy-engine/tests/test_strategy_validation.py @@ -1,13 +1,11 @@ import pytest - -from strategies.rsi_strategy import RsiStrategy -from strategies.macd_strategy import MacdStrategy from strategies.bollinger_strategy import BollingerStrategy from strategies.ema_crossover_strategy import EmaCrossoverStrategy from strategies.grid_strategy import GridStrategy -from strategies.vwap_strategy import VwapStrategy +from strategies.macd_strategy import MacdStrategy +from strategies.rsi_strategy import RsiStrategy from strategies.volume_profile_strategy import VolumeProfileStrategy - +from strategies.vwap_strategy import VwapStrategy # ── RSI ────────────────────────────────────────────────────────────────── diff --git a/services/strategy-engine/tests/test_volume_profile_strategy.py b/services/strategy-engine/tests/test_volume_profile_strategy.py index 65ee2e8..f47898c 100644 --- a/services/strategy-engine/tests/test_volume_profile_strategy.py +++ b/services/strategy-engine/tests/test_volume_profile_strategy.py @@ -1,18 +1,18 @@ """Tests for the Volume Profile strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.volume_profile_strategy import VolumeProfileStrategy from shared.models import Candle, OrderSide -from strategies.volume_profile_strategy import VolumeProfileStrategy def make_candle(close: float, volume: float = 1.0) -> Candle: return Candle( symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=datetime(2024, 1, 1, tzinfo=UTC), open=Decimal(str(close)), high=Decimal(str(close)), low=Decimal(str(close)), @@ -134,13 +134,10 @@ def test_volume_profile_hvn_detection(): # Create a profile with very high 
volume at price ~100 and low volume elsewhere # Prices range from 90 to 110, heavy volume concentrated at 100 - candles_data = [] # Low volume at extremes - for p in [90, 91, 92, 109, 110]: - candles_data.append((p, 1.0)) + candles_data = [(p, 1.0) for p in [90, 91, 92, 109, 110]] # Very high volume around 100 - for _ in range(15): - candles_data.append((100, 100.0)) + candles_data.extend((100, 100.0) for _ in range(15)) for price, vol in candles_data: strategy.on_candle(make_candle(price, vol)) @@ -148,7 +145,7 @@ def test_volume_profile_hvn_detection(): # Access the internal method to verify HVN detection result = strategy._compute_value_area() assert result is not None - poc, va_low, va_high, hvn_levels, lvn_levels = result + _poc, _va_low, _va_high, hvn_levels, _lvn_levels = result # The bin containing price ~100 should have very high volume -> HVN assert len(hvn_levels) > 0 diff --git a/services/strategy-engine/tests/test_vwap_strategy.py b/services/strategy-engine/tests/test_vwap_strategy.py index 2c34b01..078d0cf 100644 --- a/services/strategy-engine/tests/test_vwap_strategy.py +++ b/services/strategy-engine/tests/test_vwap_strategy.py @@ -1,11 +1,11 @@ """Tests for the VWAP strategy.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from decimal import Decimal +from strategies.vwap_strategy import VwapStrategy from shared.models import Candle, OrderSide -from strategies.vwap_strategy import VwapStrategy def make_candle( @@ -20,7 +20,7 @@ def make_candle( if low is None: low = close if open_time is None: - open_time = datetime(2024, 1, 1, tzinfo=timezone.utc) + open_time = datetime(2024, 1, 1, tzinfo=UTC) return Candle( symbol="AAPL", timeframe="1m", @@ -111,11 +111,11 @@ def test_vwap_daily_reset(): """Candles from two different dates cause VWAP to reset.""" strategy = _configured_strategy() - day1 = datetime(2024, 1, 1, tzinfo=timezone.utc) - day2 = datetime(2024, 1, 2, tzinfo=timezone.utc) + day1 = datetime(2024, 1, 1, 
tzinfo=UTC) + day2 = datetime(2024, 1, 2, tzinfo=UTC) # Feed 35 candles on day 1 to build VWAP state - for i in range(35): + for _i in range(35): strategy.on_candle(make_candle(100.0, high=101.0, low=99.0, open_time=day1)) # Verify state is built up |
