diff options
Diffstat (limited to 'services')
55 files changed, 2014 insertions, 523 deletions
diff --git a/services/api/tests/test_portfolio_router.py b/services/api/tests/test_portfolio_router.py index f2584ea..3bd1b2c 100644 --- a/services/api/tests/test_portfolio_router.py +++ b/services/api/tests/test_portfolio_router.py @@ -45,7 +45,7 @@ def test_get_positions_with_data(app, mock_db): app.state.db = db mock_row = MagicMock() - mock_row.symbol = "BTCUSDT" + mock_row.symbol = "AAPL" mock_row.quantity = Decimal("0.1") mock_row.avg_entry_price = Decimal("50000") mock_row.current_price = Decimal("55000") @@ -59,7 +59,7 @@ def test_get_positions_with_data(app, mock_db): assert response.status_code == 200 data = response.json() assert len(data) == 1 - assert data[0]["symbol"] == "BTCUSDT" + assert data[0]["symbol"] == "AAPL" def test_get_snapshots_empty(app, mock_db): diff --git a/services/backtester/src/backtester/config.py b/services/backtester/src/backtester/config.py index f7897da..57ee1fb 100644 --- a/services/backtester/src/backtester/config.py +++ b/services/backtester/src/backtester/config.py @@ -5,7 +5,7 @@ from shared.config import Settings class BacktestConfig(Settings): backtest_initial_balance: float = 10000.0 - symbol: str = "BTCUSDT" + symbol: str = "AAPL" timeframe: str = "1h" strategy_name: str = "rsi_strategy" candle_limit: int = 500 diff --git a/services/backtester/tests/test_engine.py b/services/backtester/tests/test_engine.py index 743a43b..4794e63 100644 --- a/services/backtester/tests/test_engine.py +++ b/services/backtester/tests/test_engine.py @@ -23,14 +23,14 @@ def make_candle(symbol: str, price: float, timeframe: str = "1h") -> Candle: ) -def make_candles(prices: list[float], symbol: str = "BTCUSDT") -> list[Candle]: +def make_candles(prices: list[float], symbol: str = "AAPL") -> list[Candle]: return [make_candle(symbol, p) for p in prices] def make_signal(side: OrderSide, price: str, quantity: str = "0.1") -> Signal: return Signal( strategy="test", - symbol="BTCUSDT", + symbol="AAPL", side=side, price=Decimal(price), quantity=Decimal(quantity), diff --git a/services/backtester/tests/test_metrics.py b/services/backtester/tests/test_metrics.py index 582309a..55f5b6c 100644 --- a/services/backtester/tests/test_metrics.py +++ b/services/backtester/tests/test_metrics.py @@ -12,7 +12,7 @@ from backtester.metrics import TradeRecord, compute_detailed_metrics def _make_trade(side: str, price: str, minutes_offset: int = 0) -> TradeRecord: return TradeRecord( time=datetime(2025, 1, 1, tzinfo=timezone.utc) + timedelta(minutes=minutes_offset), - symbol="BTCUSDT", + symbol="AAPL", side=side, price=Decimal(price), quantity=Decimal("1"), @@ -127,39 +127,39 @@ def test_risk_free_rate_affects_sharpe(): base = datetime(2025, 1, 1, tzinfo=timezone.utc) trades = [ TradeRecord( - time=base, symbol="BTCUSDT", side="BUY", price=Decimal("100"), quantity=Decimal("1") + time=base, symbol="AAPL", side="BUY", price=Decimal("100"), quantity=Decimal("1") ), TradeRecord( time=base + timedelta(days=1), - symbol="BTCUSDT", + symbol="AAPL", side="SELL", price=Decimal("110"), quantity=Decimal("1"), ), TradeRecord( time=base + timedelta(days=2), - symbol="BTCUSDT", + symbol="AAPL", side="BUY", price=Decimal("105"), quantity=Decimal("1"), ), TradeRecord( time=base + timedelta(days=3), - symbol="BTCUSDT", + symbol="AAPL", side="SELL", price=Decimal("115"), quantity=Decimal("1"), ), TradeRecord( time=base + timedelta(days=4), - symbol="BTCUSDT", + symbol="AAPL", side="BUY", price=Decimal("110"), quantity=Decimal("1"), ), TradeRecord( time=base + timedelta(days=5), - symbol="BTCUSDT", + symbol="AAPL", side="SELL", price=Decimal("108"), quantity=Decimal("1"), diff --git a/services/backtester/tests/test_reporter.py b/services/backtester/tests/test_reporter.py index 2ea49c0..5199b68 100644 --- a/services/backtester/tests/test_reporter.py +++ b/services/backtester/tests/test_reporter.py @@ -32,7 +32,7 @@ def _make_result(with_detailed: bool = False) -> BacktestResult: ) return BacktestResult( strategy_name="sma_crossover", - symbol="BTCUSDT", + symbol="AAPL", total_trades=10, initial_balance=Decimal("10000"), final_balance=Decimal("11500"), @@ -48,7 +48,7 @@ def test_format_report_contains_key_metrics(): report = format_report(result) assert "sma_crossover" in report - assert "BTCUSDT" in report + assert "AAPL" in report assert "10000" in report assert "11500" in report assert "1500" in report @@ -89,7 +89,7 @@ def test_export_json(): data = json.loads(json_output) assert data["strategy_name"] == "sma_crossover" - assert data["symbol"] == "BTCUSDT" + assert data["symbol"] == "AAPL" assert "detailed" in data assert data["detailed"]["sharpe_ratio"] == 1.5 assert data["detailed"]["monthly_returns"]["2025-01"] == 500.0 diff --git a/services/backtester/tests/test_simulator.py b/services/backtester/tests/test_simulator.py index a407c21..62e2cdb 100644 --- a/services/backtester/tests/test_simulator.py +++ b/services/backtester/tests/test_simulator.py @@ -36,20 +36,20 @@ def test_simulator_initial_balance(): def test_simulator_buy_reduces_balance(): sim = OrderSimulator(Decimal("10000")) - signal = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") result = sim.execute(signal) assert result is True assert sim.balance == Decimal("5000") - assert sim.positions["BTCUSDT"] == Decimal("0.1") + assert sim.positions["AAPL"] == Decimal("0.1") def test_simulator_sell_increases_balance(): sim = OrderSimulator(Decimal("10000")) - buy_signal = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + buy_signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(buy_signal) balance_after_buy = sim.balance - sell_signal = make_signal("BTCUSDT", OrderSide.SELL, "55000", "0.1") + sell_signal = make_signal("AAPL", OrderSide.SELL, "55000", "0.1") result = sim.execute(sell_signal) assert result is True assert sim.balance > balance_after_buy @@ -59,20 +59,20 @@ def test_simulator_sell_increases_balance(): def test_simulator_reject_buy_insufficient_balance(): sim = OrderSimulator(Decimal("100")) - signal = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") result = sim.execute(signal) assert result is False assert sim.balance == Decimal("100") - assert sim.positions.get("BTCUSDT", Decimal("0")) == Decimal("0") + assert sim.positions.get("AAPL", Decimal("0")) == Decimal("0") def test_simulator_trade_history(): sim = OrderSimulator(Decimal("10000")) - signal = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(signal) assert len(sim.trades) == 1 trade = sim.trades[0] - assert trade.symbol == "BTCUSDT" + assert trade.symbol == "AAPL" assert trade.side == OrderSide.BUY assert trade.price == Decimal("50000") assert trade.quantity == Decimal("0.1") @@ -86,7 +86,7 @@ def test_simulator_trade_history(): def test_slippage_on_buy(): """Buy price should increase by slippage_pct.""" sim = OrderSimulator(Decimal("100000"), slippage_pct=0.01) # 1% - signal = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(signal) trade = sim.trades[0] expected_price = Decimal("50000") * Decimal("1.01") # 50500 @@ -97,10 +97,10 @@ def test_slippage_on_sell(): """Sell price should decrease by slippage_pct.""" sim = OrderSimulator(Decimal("100000"), slippage_pct=0.01) # Buy first (no slippage check here, just need a position) - buy = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + buy = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(buy) # Sell - sell = make_signal("BTCUSDT", OrderSide.SELL, "50000", "0.1") + sell = make_signal("AAPL", OrderSide.SELL, "50000", "0.1") sim.execute(sell) trade = sim.trades[1] expected_price = Decimal("50000") * Decimal("0.99") # 49500 @@ -116,7 +116,7 @@ def test_fee_deducted_from_balance(): """Fees should reduce balance beyond the raw cost.""" fee_pct = 0.001 # 0.1% sim = OrderSimulator(Decimal("100000"), taker_fee_pct=fee_pct) - signal = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(signal) # cost = 50000 * 0.1 = 5000, fee = 5000 * 0.001 = 5 expected_balance = Decimal("100000") - Decimal("5000") - Decimal("5") @@ -132,7 +132,7 @@ def test_fee_deducted_from_balance(): def test_stop_loss_triggers(): """Long position auto-closed when candle_low <= stop_loss.""" sim = OrderSimulator(Decimal("100000")) - signal = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(signal, stop_loss=Decimal("48000")) ts = datetime(2025, 1, 1, tzinfo=timezone.utc) @@ -150,7 +150,7 @@ def test_stop_loss_triggers(): def test_take_profit_triggers(): """Long position auto-closed when candle_high >= take_profit.""" sim = OrderSimulator(Decimal("100000")) - signal = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(signal, take_profit=Decimal("55000")) ts = datetime(2025, 1, 1, tzinfo=timezone.utc) @@ -168,7 +168,7 @@ def test_take_profit_triggers(): def test_stop_not_triggered_within_range(): """No auto-close when price stays within stop/tp range.""" sim = OrderSimulator(Decimal("100000")) - signal = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + signal = make_signal("AAPL", OrderSide.BUY, "50000", "0.1") sim.execute(signal, stop_loss=Decimal("48000"), take_profit=Decimal("55000")) ts = datetime(2025, 1, 1, tzinfo=timezone.utc) @@ -189,10 +189,10 @@ def test_stop_not_triggered_within_range(): def test_short_sell_allowed(): """Can open short position with allow_short=True.""" sim = OrderSimulator(Decimal("100000"), allow_short=True) - signal = make_signal("BTCUSDT", OrderSide.SELL, "50000", "0.1") + signal = make_signal("AAPL", OrderSide.SELL, "50000", "0.1") result = sim.execute(signal) assert result is True - assert sim.positions["BTCUSDT"] == Decimal("-0.1") + assert sim.positions["AAPL"] == Decimal("-0.1") assert len(sim.open_positions) == 1 assert sim.open_positions[0].side == OrderSide.SELL @@ -200,16 +200,16 @@ def test_short_sell_allowed(): def test_short_sell_rejected(): """Short rejected when allow_short=False (default).""" sim = OrderSimulator(Decimal("100000"), allow_short=False) - signal = make_signal("BTCUSDT", OrderSide.SELL, "50000", "0.1") + signal = make_signal("AAPL", OrderSide.SELL, "50000", "0.1") result = sim.execute(signal) assert result is False - assert sim.positions.get("BTCUSDT", Decimal("0")) == Decimal("0") + assert sim.positions.get("AAPL", Decimal("0")) == Decimal("0") def test_short_stop_loss(): """Short position stop-loss triggers on candle high >= stop_loss.""" sim = OrderSimulator(Decimal("100000"), allow_short=True) - signal = make_signal("BTCUSDT", OrderSide.SELL, "50000", "0.1") + signal = make_signal("AAPL", OrderSide.SELL, "50000", "0.1") sim.execute(signal, stop_loss=Decimal("52000")) ts = datetime(2025, 1, 1, tzinfo=timezone.utc) diff --git a/services/backtester/tests/test_walk_forward.py b/services/backtester/tests/test_walk_forward.py index 5ab2e7b..96abb6e 100644 --- a/services/backtester/tests/test_walk_forward.py +++ b/services/backtester/tests/test_walk_forward.py @@ -21,7 +21,7 @@ def _generate_candles(n=100, base_price=100.0): price = base_price + (i % 20) - 10 candles.append( Candle( - symbol="BTCUSDT", + symbol="AAPL", timeframe="1h", open_time=datetime(2025, 1, 1, tzinfo=timezone.utc) + timedelta(hours=i), open=Decimal(str(price)), diff --git a/services/data-collector/pyproject.toml b/services/data-collector/pyproject.toml index 5fba78f..48282c3 100644 --- a/services/data-collector/pyproject.toml +++ b/services/data-collector/pyproject.toml @@ -1,13 +1,9 @@ [project] name = "data-collector" version = "0.1.0" -description = "Binance market data collector service" +description = "Alpaca market data collector service" requires-python = ">=3.12" -dependencies = [ - "ccxt>=4.0", - "websockets>=12.0", - "trading-shared", -] +dependencies = ["trading-shared"] [project.optional-dependencies] dev = [ diff --git a/services/data-collector/src/data_collector/binance_rest.py b/services/data-collector/src/data_collector/binance_rest.py deleted file mode 100644 index eaf4e30..0000000 --- a/services/data-collector/src/data_collector/binance_rest.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Binance REST API helpers for fetching historical candle data.""" - -from datetime import datetime, timezone -from decimal import Decimal - -from shared.models import Candle - - -def _normalize_symbol(symbol: str) -> str: - """Convert 'BTC/USDT' to 'BTCUSDT'.""" - return symbol.replace("/", "") - - -async def fetch_historical_candles( - exchange, - symbol: str, - timeframe: str, - since: int, - limit: int = 500, -) -> list[Candle]: - """Fetch historical OHLCV candles from the exchange and return Candle models. - - Args: - exchange: An async ccxt exchange instance. - symbol: Market symbol, e.g. 'BTC/USDT'. - timeframe: Candle timeframe, e.g. '1m'. - since: Start timestamp in milliseconds. - limit: Maximum number of candles to fetch. - - Returns: - A list of Candle model instances. - """ - rows = await exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit) - - normalized = _normalize_symbol(symbol) - candles: list[Candle] = [] - - for row in rows: - ts_ms, o, h, low, c, v = row - open_time = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc) - candles.append( - Candle( - symbol=normalized, - timeframe=timeframe, - open_time=open_time, - open=Decimal(str(o)), - high=Decimal(str(h)), - low=Decimal(str(low)), - close=Decimal(str(c)), - volume=Decimal(str(v)), - ) - ) - - return candles diff --git a/services/data-collector/src/data_collector/binance_ws.py b/services/data-collector/src/data_collector/binance_ws.py deleted file mode 100644 index e25e7a6..0000000 --- a/services/data-collector/src/data_collector/binance_ws.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Binance WebSocket client for real-time kline/candle data. - -NOTE: This module is Binance-specific (uses Binance WebSocket URL and message format). -Multi-exchange WebSocket support would require exchange-specific implementations. -""" - -import asyncio -import json -import logging -from datetime import datetime, timezone -from decimal import Decimal -from typing import Callable, Awaitable - -import websockets - -from shared.models import Candle - -logger = logging.getLogger(__name__) - -BINANCE_WS_URL = "wss://stream.binance.com:9443/ws" -RECONNECT_DELAY = 5 # seconds - - -def _normalize_symbol(symbol: str) -> str: - """Convert 'BTC/USDT' to 'BTCUSDT'.""" - return symbol.replace("/", "") - - -def _stream_name(symbol: str, timeframe: str) -> str: - """Build Binance stream name, e.g. 'btcusdt@kline_1m'.""" - return f"{_normalize_symbol(symbol).lower()}@kline_{timeframe}" - - -class BinanceWebSocket: - """Connects to Binance WebSocket streams and emits closed candles.""" - - def __init__( - self, - symbols: list[str], - timeframe: str, - on_candle: Callable[[Candle], Awaitable[None]], - ) -> None: - self._symbols = symbols - self._timeframe = timeframe - self._on_candle = on_candle - self._running = False - - def _build_subscribe_message(self) -> dict: - streams = [_stream_name(s, self._timeframe) for s in self._symbols] - return { - "method": "SUBSCRIBE", - "params": streams, - "id": 1, - } - - def _parse_candle(self, message: dict) -> Candle | None: - """Parse a kline WebSocket message into a Candle, or None if not closed.""" - k = message.get("k") - if k is None: - return None - if not k.get("x"): # only closed candles - return None - - symbol = k["s"] # already normalized, e.g. 'BTCUSDT' - open_time = datetime.fromtimestamp(k["t"] / 1000, tz=timezone.utc) - return Candle( - symbol=symbol, - timeframe=self._timeframe, - open_time=open_time, - open=Decimal(k["o"]), - high=Decimal(k["h"]), - low=Decimal(k["l"]), - close=Decimal(k["c"]), - volume=Decimal(k["v"]), - ) - - async def _run_once(self) -> None: - """Single connection attempt; processes messages until disconnected.""" - async with websockets.connect(BINANCE_WS_URL) as ws: - subscribe_msg = self._build_subscribe_message() - await ws.send(json.dumps(subscribe_msg)) - logger.info("Subscribed to Binance streams: %s", subscribe_msg["params"]) - - async for raw in ws: - if not self._running: - break - try: - message = json.loads(raw) - candle = self._parse_candle(message) - if candle is not None: - await self._on_candle(candle) - except Exception: - logger.exception("Error processing WebSocket message: %s", raw) - - async def start(self) -> None: - """Connect to Binance WebSocket and process messages, auto-reconnecting.""" - self._running = True - while self._running: - try: - await self._run_once() - except Exception: - if not self._running: - break - logger.warning("WebSocket disconnected. Reconnecting in %ds…", RECONNECT_DELAY) - await asyncio.sleep(RECONNECT_DELAY) - - def stop(self) -> None: - """Signal the WebSocket loop to stop after the current message.""" - self._running = False diff --git a/services/data-collector/src/data_collector/config.py b/services/data-collector/src/data_collector/config.py index 1e080e5..dd430e6 100644 --- a/services/data-collector/src/data_collector/config.py +++ b/services/data-collector/src/data_collector/config.py @@ -1,6 +1,9 @@ +"""Data Collector configuration.""" + from shared.config import Settings class CollectorConfig(Settings): - symbols: list[str] = ["BTC/USDT"] - timeframes: list[str] = ["1m"] + symbols: list[str] = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"] + timeframes: list[str] = ["5Min"] + poll_interval_seconds: int = 60 diff --git a/services/data-collector/src/data_collector/main.py b/services/data-collector/src/data_collector/main.py index eebe14a..b42b34c 100644 --- a/services/data-collector/src/data_collector/main.py +++ b/services/data-collector/src/data_collector/main.py @@ -1,59 +1,74 @@ -"""Data Collector Service entry point.""" +"""Data Collector Service — fetches US stock data from Alpaca.""" import asyncio +from shared.alpaca import AlpacaClient from shared.broker import RedisBroker from shared.db import Database +from shared.events import CandleEvent from shared.healthcheck import HealthCheckServer from shared.logging import setup_logging from shared.metrics import ServiceMetrics +from shared.models import Candle from shared.notifier import TelegramNotifier from data_collector.config import CollectorConfig -from data_collector.storage import CandleStorage -from data_collector.ws_factory import create_websocket - -# Health check port: base (HEALTH_PORT, default 8080) + offset -# data-collector: +0 (8080), strategy-engine: +1 (8081) -# order-executor: +2 (8082), portfolio-manager: +3 (8083) +# Health check port: base + 0 HEALTH_PORT_OFFSET = 0 +async def fetch_latest_bars( + alpaca: AlpacaClient, + symbols: list[str], + timeframe: str, + log, +) -> list[Candle]: + """Fetch latest bar for each symbol from Alpaca.""" + candles = [] + for symbol in symbols: + try: + bars = await alpaca.get_bars(symbol, timeframe=timeframe, limit=1) + if bars: + bar = bars[-1] + from datetime import datetime + from decimal import Decimal + + candle = Candle( + symbol=symbol, + timeframe=timeframe, + open_time=datetime.fromisoformat(bar["t"].replace("Z", "+00:00")), + open=Decimal(str(bar["o"])), + high=Decimal(str(bar["h"])), + low=Decimal(str(bar["l"])), + close=Decimal(str(bar["c"])), + volume=Decimal(str(bar["v"])), + ) + candles.append(candle) + except Exception as exc: + log.warning("fetch_bar_failed", symbol=symbol, error=str(exc)) + return candles + + async def run() -> None: - """Initialise all components and start the WebSocket collector.""" config = CollectorConfig() log = setup_logging("data-collector", config.log_level, config.log_format) metrics = ServiceMetrics("data_collector") + notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id + bot_token=config.telegram_bot_token, + chat_id=config.telegram_chat_id, ) db = Database(config.database_url) await db.connect() - await db.init_tables() broker = RedisBroker(config.redis_url) - storage = CandleStorage(db=db, broker=broker) - - async def on_candle(candle): - log.info( - "candle_received", - symbol=candle.symbol, - timeframe=candle.timeframe, - open_time=str(candle.open_time), - ) - await storage.store(candle) - metrics.events_processed.labels(service="data-collector", event_type="candle").inc() - - # Use the first configured timeframe for the WebSocket subscription. - timeframe = config.timeframes[0] if config.timeframes else "1m" - - ws = create_websocket( - exchange_id=config.exchange_id, - symbols=config.symbols, - timeframe=timeframe, - on_candle=on_candle, + + alpaca = AlpacaClient( + api_key=config.alpaca_api_key, + api_secret=config.alpaca_api_secret, + paper=config.alpaca_paper, ) health = HealthCheckServer( @@ -61,18 +76,38 @@ async def run() -> None: port=config.health_port + HEALTH_PORT_OFFSET, auth_token=config.metrics_auth_token, ) - health.register_check("redis", broker.ping) await health.start() metrics.service_up.labels(service="data-collector").set(1) - log.info( - "service_started", - symbols=config.symbols, - timeframe=timeframe, - ) + poll_interval = int(getattr(config, "poll_interval_seconds", 60)) + symbols = config.symbols + timeframe = config.timeframes[0] if config.timeframes else "1Day" + + log.info("starting", symbols=symbols, timeframe=timeframe, poll_interval=poll_interval) try: - await ws.start() + while True: + # Check if market is open + try: + is_open = await alpaca.is_market_open() + except Exception: + is_open = False + + if is_open: + candles = await fetch_latest_bars(alpaca, symbols, timeframe, log) + for candle in candles: + await db.insert_candle(candle) + event = CandleEvent(data=candle) + stream = f"candles.{candle.symbol}" + await broker.publish(stream, event.to_dict()) + metrics.events_processed.labels( + service="data-collector", event_type="candle" + ).inc() + log.info("candle_stored", symbol=candle.symbol, close=str(candle.close)) + else: + log.debug("market_closed") + + await asyncio.sleep(poll_interval) except Exception as exc: log.error("fatal_error", error=str(exc)) await notifier.send_error(str(exc), "data-collector") @@ -80,6 +115,7 @@ async def run() -> None: finally: metrics.service_up.labels(service="data-collector").set(0) await notifier.close() + await alpaca.close() await broker.close() await db.close() diff --git a/services/data-collector/src/data_collector/ws_factory.py b/services/data-collector/src/data_collector/ws_factory.py deleted file mode 100644 index e068399..0000000 --- a/services/data-collector/src/data_collector/ws_factory.py +++ /dev/null @@ -1,34 +0,0 @@ -"""WebSocket factory for exchange-specific connections.""" - -import logging - -from data_collector.binance_ws import BinanceWebSocket - -logger = logging.getLogger(__name__) - -# Supported exchanges for WebSocket streaming -SUPPORTED_WS = {"binance": BinanceWebSocket} - - -def create_websocket(exchange_id: str, **kwargs): - """Create an exchange-specific WebSocket handler. - - Args: - exchange_id: Exchange identifier (e.g. 'binance') - **kwargs: Passed to the WebSocket constructor (symbols, timeframe, on_candle) - - Returns: - WebSocket handler instance - - Raises: - ValueError: If exchange is not supported for WebSocket streaming - """ - ws_cls = SUPPORTED_WS.get(exchange_id) - if ws_cls is None: - supported = ", ".join(sorted(SUPPORTED_WS.keys())) - raise ValueError( - f"WebSocket streaming not supported for '{exchange_id}'. " - f"Supported: {supported}. " - f"Use REST polling as fallback for unsupported exchanges." - ) - return ws_cls(**kwargs) diff --git a/services/data-collector/tests/test_binance_rest.py b/services/data-collector/tests/test_binance_rest.py deleted file mode 100644 index bf88210..0000000 --- a/services/data-collector/tests/test_binance_rest.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Tests for binance_rest module.""" - -import pytest -from decimal import Decimal -from unittest.mock import AsyncMock, MagicMock -from datetime import datetime, timezone - -from data_collector.binance_rest import fetch_historical_candles - - -@pytest.mark.asyncio -async def test_fetch_historical_candles_parses_response(): - """Verify that OHLCV rows are correctly parsed into Candle models.""" - ts = 1700000000000 # milliseconds - mock_exchange = MagicMock() - mock_exchange.fetch_ohlcv = AsyncMock( - return_value=[ - [ts, 30000.0, 30100.0, 29900.0, 30050.0, 1.5], - [ts + 60000, 30050.0, 30200.0, 30000.0, 30150.0, 2.0], - ] - ) - - candles = await fetch_historical_candles(mock_exchange, "BTC/USDT", "1m", since=ts, limit=500) - - assert len(candles) == 2 - - c = candles[0] - assert c.symbol == "BTCUSDT" - assert c.timeframe == "1m" - assert c.open_time == datetime.fromtimestamp(ts / 1000, tz=timezone.utc) - assert c.open == Decimal("30000.0") - assert c.high == Decimal("30100.0") - assert c.low == Decimal("29900.0") - assert c.close == Decimal("30050.0") - assert c.volume == Decimal("1.5") - - mock_exchange.fetch_ohlcv.assert_called_once_with("BTC/USDT", "1m", since=ts, limit=500) - - -@pytest.mark.asyncio -async def test_fetch_historical_candles_empty_response(): - """Verify that an empty exchange response returns an empty list.""" - mock_exchange = MagicMock() - mock_exchange.fetch_ohlcv = AsyncMock(return_value=[]) - - candles = await fetch_historical_candles(mock_exchange, "BTC/USDT", "1m", since=1700000000000) - - assert candles == [] diff --git a/services/data-collector/tests/test_storage.py b/services/data-collector/tests/test_storage.py index be85578..ffffa40 100644 --- a/services/data-collector/tests/test_storage.py +++ b/services/data-collector/tests/test_storage.py @@ -9,7 +9,7 @@ from shared.models import Candle from data_collector.storage import CandleStorage -def _make_candle(symbol: str = "BTCUSDT") -> Candle: +def _make_candle(symbol: str = "AAPL") -> Candle: return Candle( symbol=symbol, timeframe="1m", @@ -39,11 +39,11 @@ async def test_storage_saves_to_db_and_publishes(): mock_broker.publish.assert_called_once() stream_arg = mock_broker.publish.call_args[0][0] - assert stream_arg == "candles.BTCUSDT" + assert stream_arg == "candles.AAPL" data_arg = mock_broker.publish.call_args[0][1] assert data_arg["type"] == "CANDLE" - assert data_arg["data"]["symbol"] == "BTCUSDT" + assert data_arg["data"]["symbol"] == "AAPL" @pytest.mark.asyncio diff --git a/services/data-collector/tests/test_ws_factory.py b/services/data-collector/tests/test_ws_factory.py deleted file mode 100644 index cdddcca..0000000 --- a/services/data-collector/tests/test_ws_factory.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Tests for WebSocket factory.""" - -import pytest -from data_collector.ws_factory import create_websocket, SUPPORTED_WS -from data_collector.binance_ws import BinanceWebSocket - - -def test_create_binance_ws(): - ws = create_websocket("binance", symbols=["BTCUSDT"], timeframe="1m", on_candle=lambda c: None) - assert isinstance(ws, BinanceWebSocket) - - -def test_create_unsupported_exchange(): - with pytest.raises(ValueError, match="not supported"): - create_websocket( - "unsupported_exchange", symbols=["BTCUSDT"], timeframe="1m", on_candle=lambda c: None - ) - - -def test_supported_exchanges(): - assert "binance" in SUPPORTED_WS diff --git a/services/order-executor/pyproject.toml b/services/order-executor/pyproject.toml index eed4fef..7bb1030 100644 --- a/services/order-executor/pyproject.toml +++ b/services/order-executor/pyproject.toml @@ -3,7 +3,7 @@ name = "order-executor" version = "0.1.0" description = "Order execution service with risk management" requires-python = ">=3.12" -dependencies = ["ccxt>=4.0", "trading-shared"] +dependencies = ["trading-shared"] [project.optional-dependencies] dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] diff --git a/services/order-executor/src/order_executor/executor.py b/services/order-executor/src/order_executor/executor.py index 80f441d..a71e762 100644 --- a/services/order-executor/src/order_executor/executor.py +++ b/services/order-executor/src/order_executor/executor.py @@ -37,12 +37,8 @@ class OrderExecutor: async def execute(self, signal: Signal) -> Optional[Order]: """Run risk checks and place an order for the given signal.""" - # Fetch current balance from exchange - balance_data = await self.exchange.fetch_balance() - # Use USDT (or quote currency) free balance as available capital - free_balances = balance_data.get("free", {}) - quote_currency = signal.symbol.split("/")[-1] if "/" in signal.symbol else "USDT" - balance = Decimal(str(free_balances.get(quote_currency, 0))) + # Fetch buying power from Alpaca + balance = await self.exchange.get_buying_power() # Fetch current positions positions = {} @@ -84,11 +80,11 @@ class OrderExecutor: ) else: try: - await self.exchange.create_order( + await self.exchange.submit_order( symbol=signal.symbol, - type="market", + qty=float(signal.quantity), side=signal.side.value.lower(), - amount=float(signal.quantity), + type="market", ) order.status = OrderStatus.FILLED order.filled_at = datetime.now(timezone.utc) diff --git a/services/order-executor/src/order_executor/main.py b/services/order-executor/src/order_executor/main.py index 3fe4c12..51ab286 100644 --- a/services/order-executor/src/order_executor/main.py +++ b/services/order-executor/src/order_executor/main.py @@ -3,11 +3,11 @@ import asyncio from decimal import Decimal +from shared.alpaca import AlpacaClient from shared.broker import RedisBroker from shared.db import Database from shared.events import Event, EventType from shared.healthcheck import HealthCheckServer -from shared.exchange import create_exchange from shared.logging import setup_logging from shared.metrics import ServiceMetrics from shared.notifier import TelegramNotifier @@ -16,9 +16,7 @@ from order_executor.config import ExecutorConfig from order_executor.executor import OrderExecutor from order_executor.risk_manager import RiskManager -# Health check port: base (HEALTH_PORT, default 8080) + offset -# data-collector: +0 (8080), strategy-engine: +1 (8081) -# order-executor: +2 (8082), portfolio-manager: +3 (8083) +# Health check port: base + 2 HEALTH_PORT_OFFSET = 2 @@ -26,21 +24,21 @@ async def run() -> None: config = ExecutorConfig() log = setup_logging("order-executor", config.log_level, config.log_format) metrics = ServiceMetrics("order_executor") + notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id + bot_token=config.telegram_bot_token, + chat_id=config.telegram_chat_id, ) db = Database(config.database_url) await db.connect() - await db.init_tables() broker = RedisBroker(config.redis_url) - exchange = create_exchange( - exchange_id=config.exchange_id, - api_key=config.binance_api_key, - api_secret=config.binance_api_secret, - sandbox=config.exchange_sandbox, + alpaca = AlpacaClient( + api_key=config.alpaca_api_key, + api_secret=config.alpaca_api_secret, + paper=config.alpaca_paper, ) risk_manager = RiskManager( @@ -51,10 +49,19 @@ async def run() -> None: max_open_positions=config.risk_max_open_positions, volatility_lookback=config.risk_volatility_lookback, volatility_scale=config.risk_volatility_scale, + max_portfolio_exposure=config.risk_max_portfolio_exposure, + max_correlated_exposure=config.risk_max_correlated_exposure, + correlation_threshold=config.risk_correlation_threshold, + var_confidence=config.risk_var_confidence, + var_limit_pct=config.risk_var_limit_pct, + drawdown_reduction_threshold=config.risk_drawdown_reduction_threshold, + drawdown_halt_threshold=config.risk_drawdown_halt_threshold, + max_consecutive_losses=config.risk_max_consecutive_losses, + loss_pause_minutes=config.risk_loss_pause_minutes, ) executor = OrderExecutor( - exchange=exchange, + exchange=alpaca, risk_manager=risk_manager, broker=broker, db=db, @@ -62,41 +69,34 @@ async def run() -> None: dry_run=config.dry_run, ) - GROUP = "order-executor" - CONSUMER = "executor-1" - stream = "signals" - health = HealthCheckServer( "order-executor", port=config.health_port + HEALTH_PORT_OFFSET, auth_token=config.metrics_auth_token, ) - health.register_check("redis", broker.ping) await health.start() metrics.service_up.labels(service="order-executor").set(1) - log.info("service_started", stream=stream, dry_run=config.dry_run) + GROUP = "order-executor" + CONSUMER = "executor-1" + stream = "signals" await broker.ensure_group(stream, GROUP) - # Process pending messages first (from previous crash) - pending = await broker.read_pending(stream, GROUP, CONSUMER) - for msg_id, msg in pending: - try: - event = Event.from_dict(msg) - if event.type == EventType.SIGNAL: - signal = event.data - log.info( - "processing_pending_signal", signal_id=str(signal.id), symbol=signal.symbol - ) - await executor.execute(signal) - metrics.events_processed.labels(service="order-executor", event_type="signal").inc() - await broker.ack(stream, GROUP, msg_id) - except Exception as exc: - log.error("pending_process_failed", error=str(exc), msg_id=msg_id) - metrics.errors_total.labels(service="order-executor", error_type="processing").inc() + log.info("started", stream=stream, dry_run=config.dry_run) try: + # Process pending messages first + pending = await broker.read_pending(stream, GROUP, CONSUMER) + for msg_id, msg in pending: + try: + event = Event.from_dict(msg) + if event.type == EventType.SIGNAL: + await executor.execute(event.data) + await broker.ack(stream, GROUP, msg_id) + except Exception as exc: + log.error("pending_failed", error=str(exc), msg_id=msg_id) + while True: messages = await broker.read_group(stream, GROUP, CONSUMER, count=10, block=5000) for msg_id, msg in messages: @@ -104,29 +104,23 @@ async def run() -> None: event = Event.from_dict(msg) if event.type == EventType.SIGNAL: signal = event.data - log.info( - "processing_signal", signal_id=str(signal.id), symbol=signal.symbol - ) + log.info("processing_signal", signal_id=signal.id, symbol=signal.symbol) await executor.execute(signal) metrics.events_processed.labels( service="order-executor", event_type="signal" ).inc() await broker.ack(stream, GROUP, msg_id) except Exception as exc: - log.error("message_processing_failed", error=str(exc), msg_id=msg_id) + log.error("process_failed", error=str(exc)) metrics.errors_total.labels( service="order-executor", error_type="processing" ).inc() - except Exception as exc: - log.error("fatal_error", error=str(exc)) - await notifier.send_error(str(exc), "order-executor") - raise finally: metrics.service_up.labels(service="order-executor").set(0) await notifier.close() await broker.close() await db.close() - await exchange.close() + await alpaca.close() def main() -> None: diff --git a/services/order-executor/src/order_executor/risk_manager.py b/services/order-executor/src/order_executor/risk_manager.py index c3578a7..5a05746 100644 --- a/services/order-executor/src/order_executor/risk_manager.py +++ b/services/order-executor/src/order_executor/risk_manager.py @@ -1,6 +1,7 @@ """Risk management for order execution.""" from dataclasses import dataclass +from datetime import datetime, timezone, timedelta from decimal import Decimal from collections import deque import math @@ -46,6 +47,15 @@ class RiskManager: max_open_positions: int = 10, volatility_lookback: int = 20, volatility_scale: bool = False, + max_portfolio_exposure: float = 0.8, + max_correlated_exposure: float = 0.5, + correlation_threshold: float = 0.7, + var_confidence: float = 0.95, + var_limit_pct: float = 5.0, + drawdown_reduction_threshold: float = 0.1, # Start reducing at 10% drawdown + drawdown_halt_threshold: float = 0.2, # Halt trading at 20% drawdown + max_consecutive_losses: int = 5, # Pause after 5 consecutive losses + loss_pause_minutes: int = 60, # Pause for 60 minutes after consecutive losses ) -> None: self.max_position_size = max_position_size self.stop_loss_pct = stop_loss_pct @@ -57,6 +67,75 @@ class RiskManager: self._trailing_stops: dict[str, TrailingStop] = {} self._price_history: dict[str, deque[float]] = {} + self._return_history: dict[str, list[float]] = {} + self._max_portfolio_exposure = Decimal(str(max_portfolio_exposure)) + self._max_correlated_exposure = Decimal(str(max_correlated_exposure)) + self._correlation_threshold = correlation_threshold + self._var_confidence = var_confidence + self._var_limit_pct = Decimal(str(var_limit_pct)) + + self._drawdown_reduction_threshold = drawdown_reduction_threshold + self._drawdown_halt_threshold = drawdown_halt_threshold + self._max_consecutive_losses = max_consecutive_losses + self._loss_pause_minutes = loss_pause_minutes + + self._peak_balance: Decimal = Decimal("0") + self._consecutive_losses: int = 0 + self._paused_until: datetime | None = None + + def update_balance(self, current_balance: Decimal) -> None: + """Track peak balance for drawdown calculation.""" + if current_balance > self._peak_balance: + self._peak_balance = current_balance + + def get_current_drawdown(self, current_balance: Decimal) -> float: + """Calculate current drawdown from peak as a fraction (0.0 to 1.0).""" + if self._peak_balance <= 0: + return 0.0 + dd = float((self._peak_balance - current_balance) / self._peak_balance) + return max(dd, 0.0) + + def get_position_scale(self, current_balance: Decimal) -> float: + """Get position size multiplier based on current drawdown. + + Returns 1.0 (full size) when no drawdown. + Linearly reduces to 0.25 between reduction threshold and halt threshold. + Returns 0.0 at or beyond halt threshold. + """ + dd = self.get_current_drawdown(current_balance) + + if dd >= self._drawdown_halt_threshold: + return 0.0 + + if dd >= self._drawdown_reduction_threshold: + # Linear interpolation from 1.0 to 0.25 + range_pct = (dd - self._drawdown_reduction_threshold) / ( + self._drawdown_halt_threshold - self._drawdown_reduction_threshold + ) + return max(1.0 - 0.75 * range_pct, 0.25) + + return 1.0 + + def record_trade_result(self, is_win: bool) -> None: + """Record a trade result for consecutive loss tracking.""" + if is_win: + self._consecutive_losses = 0 + else: + self._consecutive_losses += 1 + if self._consecutive_losses >= self._max_consecutive_losses: + self._paused_until = datetime.now(timezone.utc) + timedelta( + minutes=self._loss_pause_minutes + ) + + def is_paused(self) -> bool: + """Check if trading is paused due to consecutive losses.""" + if self._paused_until is None: + return False + if datetime.now(timezone.utc) >= self._paused_until: + self._paused_until = None + self._consecutive_losses = 0 + return False + return True def update_price(self, symbol: str, price: Decimal) -> None: """Update price tracking for trailing stops and volatility.""" @@ -120,6 +199,145 @@ class RiskManager: scale = min(target_vol / vol, 2.0) # Cap at 2x return base_size * Decimal(str(scale)) + def calculate_correlation(self, symbol_a: str, symbol_b: str) -> float | None: + """Calculate Pearson correlation between two symbols' returns.""" + hist_a = self._price_history.get(symbol_a) + hist_b = self._price_history.get(symbol_b) + if not hist_a or not hist_b or len(hist_a) < 5 or len(hist_b) < 5: + return None + + prices_a = list(hist_a) + prices_b = list(hist_b) + min_len = min(len(prices_a), len(prices_b)) + prices_a = prices_a[-min_len:] + prices_b = prices_b[-min_len:] + + returns_a = [ + (prices_a[i] - prices_a[i - 1]) / prices_a[i - 1] + for i in range(1, len(prices_a)) + if prices_a[i - 1] != 0 + ] + returns_b = [ + (prices_b[i] - prices_b[i - 1]) / prices_b[i - 1] + for i in range(1, len(prices_b)) + if prices_b[i - 1] != 0 + ] + + if len(returns_a) < 3 or len(returns_b) < 3: + return None + + min_len = min(len(returns_a), len(returns_b)) + returns_a = returns_a[-min_len:] + returns_b = returns_b[-min_len:] + + mean_a = sum(returns_a) / len(returns_a) + mean_b = sum(returns_b) / len(returns_b) + + cov = sum((a - mean_a) * (b - mean_b) for a, b in zip(returns_a, returns_b)) / len( + returns_a + ) + std_a = math.sqrt(sum((a - mean_a) ** 2 for a in returns_a) / len(returns_a)) + std_b = math.sqrt(sum((b - mean_b) ** 2 for b in returns_b) / len(returns_b)) + + if std_a == 0 or std_b == 0: + return None + + return cov / (std_a * std_b) + + def calculate_portfolio_var(self, positions: dict[str, Position], balance: Decimal) -> float: + """Calculate portfolio VaR using historical simulation. + + Returns VaR as a percentage of balance (e.g., 3.5 for 3.5%). + """ + if not positions or balance <= 0: + return 0.0 + + # Collect returns for all positioned symbols + all_returns: list[list[float]] = [] + weights: list[float] = [] + + for symbol, pos in positions.items(): + if pos.quantity <= 0: + continue + hist = self._price_history.get(symbol) + if not hist or len(hist) < 5: + continue + prices = list(hist) + returns = [ + (prices[i] - prices[i - 1]) / prices[i - 1] + for i in range(1, len(prices)) + if prices[i - 1] != 0 + ] + if returns: + all_returns.append(returns) + weight = float(pos.quantity * pos.current_price / balance) + weights.append(weight) + + if not all_returns: + return 0.0 + + # Portfolio returns (weighted sum) + min_len = min(len(r) for r in all_returns) + portfolio_returns = [] + for i in range(min_len): + pr = sum(w * r[-(min_len - i)] for w, r in zip(weights, all_returns) if len(r) > i) + portfolio_returns.append(pr) + + if not portfolio_returns: + return 0.0 + + # Historical VaR: sort returns, take the (1-confidence) percentile + sorted_returns = sorted(portfolio_returns) + index = int((1 - self._var_confidence) * len(sorted_returns)) + index = max(0, min(index, len(sorted_returns) - 1)) + var_return = sorted_returns[index] + + return abs(var_return) * 100 # As percentage + + def check_portfolio_exposure( + self, positions: dict[str, Position], balance: Decimal + ) -> RiskCheckResult: + """Check total portfolio exposure.""" + if balance <= 0: + return RiskCheckResult(allowed=True, reason="OK") + + total_exposure = sum( + pos.quantity * pos.current_price for pos in positions.values() if pos.quantity > 0 + ) + + exposure_ratio = total_exposure / balance + if exposure_ratio > self._max_portfolio_exposure: + return RiskCheckResult( + allowed=False, + reason=f"Portfolio exposure {float(exposure_ratio):.1%} exceeds max {float(self._max_portfolio_exposure):.1%}", + ) + + return RiskCheckResult(allowed=True, reason="OK") + + def check_correlation_risk( + self, signal: Signal, positions: dict[str, Position], balance: Decimal + ) -> RiskCheckResult: + """Check if adding this position creates too much correlated exposure.""" + if signal.side != OrderSide.BUY or balance <= 0: + return RiskCheckResult(allowed=True, reason="OK") + + correlated_value = signal.price * signal.quantity + + for symbol, pos in positions.items(): + if pos.quantity <= 0 or symbol == signal.symbol: + continue + corr = self.calculate_correlation(signal.symbol, symbol) + if corr is not None and abs(corr) >= self._correlation_threshold: + correlated_value += pos.quantity * pos.current_price + + if correlated_value / balance > self._max_correlated_exposure: + return RiskCheckResult( + allowed=False, + reason=f"Correlated exposure would exceed {float(self._max_correlated_exposure):.1%}", + ) + + return RiskCheckResult(allowed=True, reason="OK") + def check( self, signal: Signal, @@ -128,6 +346,21 @@ class RiskManager: daily_pnl: Decimal, ) -> RiskCheckResult: """Run risk checks against a signal and current portfolio state.""" + # Check if paused due to consecutive losses + if self.is_paused(): + return RiskCheckResult( + allowed=False, + reason=f"Trading paused until {self._paused_until.isoformat()} after {self._max_consecutive_losses} consecutive losses", + ) + + # Check drawdown halt + dd = self.get_current_drawdown(balance) + if dd >= self._drawdown_halt_threshold: + return RiskCheckResult( + allowed=False, + reason=f"Trading halted: drawdown {dd:.1%} exceeds {self._drawdown_halt_threshold:.1%}", + ) + # Check daily loss limit if balance > 0 and (daily_pnl / balance) * 100 < -self.daily_loss_limit_pct: return RiskCheckResult(allowed=False, reason="Daily loss limit exceeded") @@ -165,4 +398,22 @@ class RiskManager: ): return RiskCheckResult(allowed=False, reason="Position size exceeded") + # Portfolio-level checks + exposure_check = self.check_portfolio_exposure(positions, balance) + if not exposure_check.allowed: + return exposure_check + + corr_check = self.check_correlation_risk(signal, positions, balance) + if not corr_check.allowed: + return corr_check + + # VaR check + if positions: + var = self.calculate_portfolio_var(positions, balance) + if var > float(self._var_limit_pct): + return RiskCheckResult( + allowed=False, + reason=f"Portfolio VaR {var:.1f}% exceeds limit {float(self._var_limit_pct):.1f}%", + ) + return RiskCheckResult(allowed=True, reason="OK") diff --git a/services/order-executor/tests/test_executor.py b/services/order-executor/tests/test_executor.py index e64b6c0..dd823d7 100644 --- a/services/order-executor/tests/test_executor.py +++ b/services/order-executor/tests/test_executor.py @@ -13,7 +13,7 @@ from order_executor.risk_manager import RiskCheckResult, RiskManager def make_signal(side: OrderSide = OrderSide.BUY, price: str = "100", quantity: str = "1") -> Signal: return Signal( strategy="test", - symbol="BTC/USDT", + symbol="AAPL", side=side, price=Decimal(price), quantity=Decimal(quantity), @@ -21,10 +21,10 @@ def make_signal(side: OrderSide = OrderSide.BUY, price: str = "100", quantity: s ) -def make_mock_exchange(free_usdt: float = 10000.0) -> AsyncMock: +def make_mock_exchange(buying_power: str = "10000") -> AsyncMock: exchange = AsyncMock() - exchange.fetch_balance.return_value = {"free": {"USDT": free_usdt}} - exchange.create_order = AsyncMock(return_value={"id": "exchange-order-123"}) + exchange.get_buying_power = AsyncMock(return_value=Decimal(buying_power)) + exchange.submit_order = AsyncMock(return_value={"id": "alpaca-order-123"}) return exchange @@ -48,7 +48,7 @@ def make_mock_db() -> AsyncMock: @pytest.mark.asyncio async def test_executor_places_order_when_risk_passes(): - """When risk check passes, create_order is called and order status is FILLED.""" + """When risk check passes, submit_order is called and order status is FILLED.""" exchange = make_mock_exchange() risk_manager = make_mock_risk_manager(allowed=True) broker = make_mock_broker() @@ -68,14 +68,14 @@ async def test_executor_places_order_when_risk_passes(): assert order is not None assert order.status == OrderStatus.FILLED - exchange.create_order.assert_called_once() + exchange.submit_order.assert_called_once() db.insert_order.assert_called_once_with(order) broker.publish.assert_called_once() @pytest.mark.asyncio async def test_executor_rejects_when_risk_fails(): - """When risk check fails, create_order is not called and None is returned.""" + """When risk check fails, submit_order is not called and None is returned.""" exchange = make_mock_exchange() risk_manager = make_mock_risk_manager(allowed=False, reason="Position size exceeded") broker = make_mock_broker() @@ -94,14 +94,14 @@ async def test_executor_rejects_when_risk_fails(): order = await executor.execute(signal) assert order is None - exchange.create_order.assert_not_called() + exchange.submit_order.assert_not_called() db.insert_order.assert_not_called() broker.publish.assert_not_called() @pytest.mark.asyncio async def test_executor_dry_run_does_not_call_exchange(): - """In dry-run mode, risk passes, order is FILLED, but exchange.create_order is NOT called.""" + """In dry-run mode, risk passes, order is FILLED, but exchange.submit_order is NOT called.""" exchange = make_mock_exchange() risk_manager = make_mock_risk_manager(allowed=True) broker = make_mock_broker() @@ -121,6 +121,6 @@ async def test_executor_dry_run_does_not_call_exchange(): assert order is not None assert order.status == OrderStatus.FILLED - exchange.create_order.assert_not_called() + exchange.submit_order.assert_not_called() db.insert_order.assert_called_once_with(order) broker.publish.assert_called_once() diff --git a/services/order-executor/tests/test_risk_manager.py b/services/order-executor/tests/test_risk_manager.py index efabe73..3d5175b 100644 --- a/services/order-executor/tests/test_risk_manager.py +++ b/services/order-executor/tests/test_risk_manager.py @@ -7,7 +7,7 @@ from shared.models import OrderSide, Position, Signal from order_executor.risk_manager import RiskManager -def make_signal(side: OrderSide, price: str, quantity: str, symbol: str = "BTC/USDT") -> Signal: +def make_signal(side: OrderSide, price: str, quantity: str, symbol: str = "AAPL") -> Signal: return Signal( strategy="test", symbol=symbol, @@ -93,7 +93,7 @@ def test_risk_check_rejects_insufficient_balance(): def test_trailing_stop_set_and_trigger(): """Trailing stop should trigger when price drops below stop level.""" rm = make_risk_manager(trailing_stop_pct="5") - rm.set_trailing_stop("BTC/USDT", Decimal("100")) + rm.set_trailing_stop("AAPL", Decimal("100")) signal = make_signal(side=OrderSide.BUY, price="94", quantity="0.01") result = rm.check(signal, balance=Decimal("10000"), positions={}, daily_pnl=Decimal("0")) @@ -104,10 +104,10 @@ def test_trailing_stop_set_and_trigger(): def test_trailing_stop_updates_highest_price(): """Trailing stop should track the highest price seen.""" rm = make_risk_manager(trailing_stop_pct="5") - rm.set_trailing_stop("BTC/USDT", Decimal("100")) + rm.set_trailing_stop("AAPL", Decimal("100")) # Price rises to 120 => stop at 114 - rm.update_price("BTC/USDT", Decimal("120")) + rm.update_price("AAPL", Decimal("120")) # Price at 115 is above stop (114), should be allowed signal = make_signal(side=OrderSide.BUY, price="115", quantity="0.01") @@ -124,7 +124,7 @@ def test_trailing_stop_updates_highest_price(): def test_trailing_stop_not_triggered_above_stop(): """Trailing stop should not trigger when price is above stop level.""" rm = make_risk_manager(trailing_stop_pct="5") - rm.set_trailing_stop("BTC/USDT", Decimal("100")) + rm.set_trailing_stop("AAPL", Decimal("100")) # Price at 96 is above stop (95), should be allowed signal = make_signal(side=OrderSide.BUY, price="96", quantity="0.01") @@ -140,11 +140,11 @@ def test_max_open_positions_check(): rm = make_risk_manager(max_open_positions=2) positions = { - "BTC/USDT": make_position("BTC/USDT", "1", "100", "100"), - "ETH/USDT": make_position("ETH/USDT", "10", "50", "50"), + "AAPL": make_position("AAPL", "1", "100", "100"), + "MSFT": make_position("MSFT", "10", "50", "50"), } - signal = make_signal(side=OrderSide.BUY, price="10", quantity="1", symbol="SOL/USDT") + signal = make_signal(side=OrderSide.BUY, price="10", quantity="1", symbol="TSLA") result = rm.check(signal, balance=Decimal("10000"), positions=positions, daily_pnl=Decimal("0")) assert result.allowed is False assert result.reason == "Max open positions reached" @@ -158,14 +158,14 @@ def test_volatility_calculation(): rm = make_risk_manager(volatility_lookback=5) # No history yet - assert rm.get_volatility("BTC/USDT") is None + assert rm.get_volatility("AAPL") is None # Feed prices prices = [100, 102, 98, 105, 101] for p in prices: - rm.update_price("BTC/USDT", Decimal(str(p))) + rm.update_price("AAPL", Decimal(str(p))) - vol = rm.get_volatility("BTC/USDT") + vol = rm.get_volatility("AAPL") assert vol is not None assert vol > 0 @@ -177,9 +177,9 @@ def test_position_size_with_volatility_scaling(): # Feed volatile prices prices = [100, 120, 80, 130, 70] for p in prices: - rm.update_price("BTC/USDT", Decimal(str(p))) + rm.update_price("AAPL", Decimal(str(p))) - size = rm.calculate_position_size("BTC/USDT", Decimal("10000")) + size = rm.calculate_position_size("AAPL", Decimal("10000")) base = Decimal("10000") * Decimal("0.1") # High volatility should reduce size below base @@ -192,9 +192,177 @@ def test_position_size_without_scaling(): prices = [100, 120, 80, 130, 70] for p in prices: - rm.update_price("BTC/USDT", Decimal(str(p))) + rm.update_price("AAPL", Decimal(str(p))) - size = rm.calculate_position_size("BTC/USDT", Decimal("10000")) + size = rm.calculate_position_size("AAPL", Decimal("10000")) base = Decimal("10000") * Decimal("0.1") assert size == base + + +# --- Portfolio exposure tests --- + + +def test_portfolio_exposure_check_passes(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + max_portfolio_exposure=0.8, + ) + positions = { + "AAPL": Position( + symbol="AAPL", + quantity=Decimal("0.01"), + avg_entry_price=Decimal("50000"), + current_price=Decimal("50000"), + ) + } + result = rm.check_portfolio_exposure(positions, Decimal("10000")) + assert result.allowed # 500/10000 = 5% < 80% + + +def test_portfolio_exposure_check_rejects(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + max_portfolio_exposure=0.3, + ) + positions = { + "AAPL": Position( + symbol="AAPL", + quantity=Decimal("1"), + avg_entry_price=Decimal("50000"), + current_price=Decimal("50000"), + ) + } + result = rm.check_portfolio_exposure(positions, Decimal("10000")) + assert not result.allowed # 50000/10000 = 500% > 30% + + +def test_correlation_calculation(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + ) + # Feed identical price histories — correlation should be ~1.0 + for i in range(20): + rm.update_price("A", Decimal(str(100 + i))) + rm.update_price("B", Decimal(str(100 + i))) + corr = rm.calculate_correlation("A", "B") + assert corr is not None + assert corr > 0.9 + + +def test_var_calculation(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + ) + for i in range(30): + rm.update_price("AAPL", Decimal(str(100 + (i % 5) - 2))) + positions = { + "AAPL": Position( + symbol="AAPL", + quantity=Decimal("1"), + avg_entry_price=Decimal("100"), + current_price=Decimal("100"), + ) + } + var = rm.calculate_portfolio_var(positions, Decimal("10000")) + assert var >= 0 # Non-negative + + +# --- Drawdown position scaling tests --- + + +def test_drawdown_position_scale_full(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + drawdown_reduction_threshold=0.1, + drawdown_halt_threshold=0.2, + ) + rm.update_balance(Decimal("10000")) + scale = rm.get_position_scale(Decimal("10000")) + assert scale == 1.0 # No drawdown + + +def test_drawdown_position_scale_reduced(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + drawdown_reduction_threshold=0.1, + drawdown_halt_threshold=0.2, + ) + rm.update_balance(Decimal("10000")) + scale = rm.get_position_scale(Decimal("8500")) # 15% drawdown (between 10% and 20%) + assert 0.25 < scale < 1.0 + + +def test_drawdown_halt(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + drawdown_reduction_threshold=0.1, + drawdown_halt_threshold=0.2, + ) + rm.update_balance(Decimal("10000")) + scale = rm.get_position_scale(Decimal("7500")) # 25% drawdown + assert scale == 0.0 + + +def test_consecutive_losses_pause(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + max_consecutive_losses=3, + loss_pause_minutes=60, + ) + rm.record_trade_result(False) + rm.record_trade_result(False) + assert not rm.is_paused() + rm.record_trade_result(False) # 3rd loss + assert rm.is_paused() + + +def test_consecutive_losses_reset_on_win(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + max_consecutive_losses=3, + ) + rm.record_trade_result(False) + rm.record_trade_result(False) + rm.record_trade_result(True) # Win resets counter + rm.record_trade_result(False) + assert not rm.is_paused() # Only 1 consecutive loss + + +def test_drawdown_check_rejects_in_check(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + drawdown_halt_threshold=0.15, + ) + rm.update_balance(Decimal("10000")) + signal = Signal( + strategy="test", + symbol="AAPL", + side=OrderSide.BUY, + price=Decimal("50000"), + quantity=Decimal("0.01"), + reason="test", + ) + result = rm.check(signal, Decimal("8000"), {}, Decimal("0")) # 20% dd > 15% + assert not result.allowed + assert "halted" in result.reason.lower() diff --git a/services/portfolio-manager/tests/test_portfolio.py b/services/portfolio-manager/tests/test_portfolio.py index 768e071..365dc1a 100644 --- a/services/portfolio-manager/tests/test_portfolio.py +++ b/services/portfolio-manager/tests/test_portfolio.py @@ -10,7 +10,7 @@ def make_order(side: OrderSide, price: str, quantity: str) -> Order: """Helper to create a filled Order.""" return Order( signal_id="test-signal", - symbol="BTC/USDT", + symbol="AAPL", side=side, type=OrderType.MARKET, price=Decimal(price), @@ -24,7 +24,7 @@ def test_portfolio_add_buy_order() -> None: order = make_order(OrderSide.BUY, "50000", "0.1") tracker.apply_order(order) - position = tracker.get_position("BTC/USDT") + position = tracker.get_position("AAPL") assert position is not None assert position.quantity == Decimal("0.1") assert position.avg_entry_price == Decimal("50000") @@ -35,7 +35,7 @@ def test_portfolio_add_multiple_buys() -> None: tracker.apply_order(make_order(OrderSide.BUY, "50000", "0.1")) tracker.apply_order(make_order(OrderSide.BUY, "52000", "0.1")) - position = tracker.get_position("BTC/USDT") + position = tracker.get_position("AAPL") assert position is not None assert position.quantity == Decimal("0.2") assert position.avg_entry_price == Decimal("51000") @@ -46,7 +46,7 @@ def test_portfolio_sell_reduces_position() -> None: tracker.apply_order(make_order(OrderSide.BUY, "50000", "0.2")) tracker.apply_order(make_order(OrderSide.SELL, "55000", "0.1")) - position = tracker.get_position("BTC/USDT") + position = tracker.get_position("AAPL") assert position is not None assert position.quantity == Decimal("0.1") assert position.avg_entry_price == Decimal("50000") @@ -54,7 +54,7 @@ def test_portfolio_sell_reduces_position() -> None: def test_portfolio_no_position_returns_none() -> None: tracker = PortfolioTracker() - position = tracker.get_position("ETH/USDT") + position = tracker.get_position("MSFT") assert position is None @@ -66,7 +66,7 @@ def test_realized_pnl_on_sell() -> None: tracker.apply_order( Order( signal_id="s1", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, type=OrderType.MARKET, price=Decimal("50000"), @@ -80,7 +80,7 @@ def test_realized_pnl_on_sell() -> None: tracker.apply_order( Order( signal_id="s2", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.SELL, type=OrderType.MARKET, price=Decimal("55000"), @@ -98,7 +98,7 @@ def test_realized_pnl_on_loss() -> None: tracker.apply_order( Order( signal_id="s1", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, type=OrderType.MARKET, price=Decimal("50000"), @@ -109,7 +109,7 @@ def test_realized_pnl_on_loss() -> None: tracker.apply_order( Order( signal_id="s2", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.SELL, type=OrderType.MARKET, price=Decimal("45000"), @@ -128,7 +128,7 @@ def test_realized_pnl_accumulates() -> None: tracker.apply_order( Order( signal_id="s1", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, type=OrderType.MARKET, price=Decimal("50000"), @@ -141,7 +141,7 @@ def test_realized_pnl_accumulates() -> None: tracker.apply_order( Order( signal_id="s2", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.SELL, type=OrderType.MARKET, price=Decimal("55000"), @@ -154,7 +154,7 @@ def test_realized_pnl_accumulates() -> None: tracker.apply_order( Order( signal_id="s3", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.SELL, type=OrderType.MARKET, price=Decimal("60000"), diff --git a/services/portfolio-manager/tests/test_snapshot.py b/services/portfolio-manager/tests/test_snapshot.py index a464599..ec5e92d 100644 --- a/services/portfolio-manager/tests/test_snapshot.py +++ b/services/portfolio-manager/tests/test_snapshot.py @@ -13,7 +13,7 @@ class TestSaveSnapshot: from portfolio_manager.main import save_snapshot pos = Position( - symbol="BTCUSDT", + symbol="AAPL", quantity=Decimal("0.5"), avg_entry_price=Decimal("50000"), current_price=Decimal("52000"), diff --git a/services/strategy-engine/src/strategy_engine/config.py b/services/strategy-engine/src/strategy_engine/config.py index e3a49c2..9fd9c49 100644 --- a/services/strategy-engine/src/strategy_engine/config.py +++ b/services/strategy-engine/src/strategy_engine/config.py @@ -4,6 +4,6 @@ from shared.config import Settings class StrategyConfig(Settings): - symbols: list[str] = ["BTC/USDT"] + symbols: list[str] = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"] timeframes: list[str] = ["1m"] strategy_params: dict = {} diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py index 4549f70..30de528 100644 --- a/services/strategy-engine/src/strategy_engine/main.py +++ b/services/strategy-engine/src/strategy_engine/main.py @@ -67,7 +67,6 @@ async def run() -> None: task = asyncio.create_task(process_symbol(engine, stream, log)) tasks.append(task) - # Wait for all symbol processors (they run forever until cancelled) await asyncio.gather(*tasks) except Exception as exc: log.error("fatal_error", error=str(exc)) diff --git a/services/strategy-engine/strategies/bollinger_strategy.py b/services/strategy-engine/strategies/bollinger_strategy.py index e53ecaa..ebe7967 100644 --- a/services/strategy-engine/strategies/bollinger_strategy.py +++ b/services/strategy-engine/strategies/bollinger_strategy.py @@ -19,6 +19,9 @@ class BollingerStrategy(BaseStrategy): self._quantity: Decimal = Decimal("0.01") self._was_below_lower: bool = False self._was_above_upper: bool = False + self._squeeze_threshold: float = 0.01 # Bandwidth below this = squeeze + self._in_squeeze: bool = False + self._squeeze_bars: int = 0 # How many bars in squeeze @property def warmup_period(self) -> int: @@ -28,6 +31,7 @@ class BollingerStrategy(BaseStrategy): self._period = int(params.get("period", 20)) self._num_std = float(params.get("num_std", 2.0)) self._min_bandwidth = float(params.get("min_bandwidth", 0.02)) + self._squeeze_threshold = float(params.get("squeeze_threshold", 0.01)) self._quantity = Decimal(str(params.get("quantity", "0.01"))) if self._period < 2: @@ -46,9 +50,12 @@ class BollingerStrategy(BaseStrategy): ) def reset(self) -> None: + super().reset() self._closes.clear() self._was_below_lower = False self._was_above_upper = False + self._in_squeeze = False + self._squeeze_bars = 0 def _bollinger_conviction(self, price: float, band: float, sma: float) -> float: """Map distance from band to conviction (0.1-1.0). @@ -75,12 +82,56 @@ class BollingerStrategy(BaseStrategy): upper = sma + self._num_std * std lower = sma - self._num_std * std + price = float(candle.close) + + # %B calculation + bandwidth = (upper - lower) / sma if sma > 0 else 0 + pct_b = (price - lower) / (upper - lower) if (upper - lower) > 0 else 0.5 + + # Squeeze detection + if bandwidth < self._squeeze_threshold: + self._in_squeeze = True + self._squeeze_bars += 1 + return None # Don't trade during squeeze, wait for breakout + elif self._in_squeeze: + # Squeeze just ended — breakout! + self._in_squeeze = False + squeeze_duration = self._squeeze_bars + self._squeeze_bars = 0 + + if price > sma: + # Breakout upward + conv = min(0.5 + squeeze_duration * 0.1, 1.0) + return self._apply_filters( + Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.BUY, + price=candle.close, + quantity=self._quantity, + conviction=conv, + reason=f"Bollinger squeeze breakout UP after {squeeze_duration} bars", + ) + ) + else: + # Breakout downward + conv = min(0.5 + squeeze_duration * 0.1, 1.0) + return self._apply_filters( + Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.SELL, + price=candle.close, + quantity=self._quantity, + conviction=conv, + reason=f"Bollinger squeeze breakout DOWN after {squeeze_duration} bars", + ) + ) + # Bandwidth filter: skip sideways markets - if sma != 0 and (upper - lower) / sma < self._min_bandwidth: + if sma != 0 and bandwidth < self._min_bandwidth: return None - price = float(candle.close) - # Track band penetration if price < lower: self._was_below_lower = True @@ -90,14 +141,14 @@ class BollingerStrategy(BaseStrategy): # BUY: was below lower band and recovered back inside if self._was_below_lower and price >= lower: self._was_below_lower = False - conviction = self._bollinger_conviction(price, lower, sma) + conv = max(1.0 - pct_b, 0.3) # Closer to lower band = higher conviction signal = Signal( strategy=self.name, symbol=candle.symbol, side=OrderSide.BUY, price=candle.close, quantity=self._quantity, - conviction=conviction, + conviction=conv, reason=f"Price recovered above lower Bollinger Band ({lower:.2f})", ) return self._apply_filters(signal) @@ -105,14 +156,14 @@ class BollingerStrategy(BaseStrategy): # SELL: was above upper band and recovered back inside if self._was_above_upper and price <= upper: self._was_above_upper = False - conviction = self._bollinger_conviction(price, upper, sma) + conv = max(pct_b, 0.3) # Closer to upper band = higher conviction signal = Signal( strategy=self.name, symbol=candle.symbol, side=OrderSide.SELL, price=candle.close, quantity=self._quantity, - conviction=conviction, + conviction=conv, reason=f"Price recovered below upper Bollinger Band ({upper:.2f})", ) return self._apply_filters(signal) diff --git a/services/strategy-engine/strategies/combined_strategy.py b/services/strategy-engine/strategies/combined_strategy.py index be1cbed..ba92485 100644 --- a/services/strategy-engine/strategies/combined_strategy.py +++ b/services/strategy-engine/strategies/combined_strategy.py @@ -20,6 +20,9 @@ class CombinedStrategy(BaseStrategy): self._strategies: list[tuple[BaseStrategy, float]] = [] # (strategy, weight) self._threshold: float = 0.5 self._quantity: Decimal = Decimal("0.01") + self._trade_history: dict[str, list[bool]] = {} # strategy_name -> [win, loss, ...] + self._adaptive_weights: bool = False + self._history_window: int = 20 # Last N signals to evaluate @property def warmup_period(self) -> int: @@ -30,6 +33,8 @@ class CombinedStrategy(BaseStrategy): def configure(self, params: dict) -> None: self._threshold = float(params.get("threshold", 0.5)) self._quantity = Decimal(str(params.get("quantity", "0.01"))) + self._adaptive_weights = bool(params.get("adaptive_weights", False)) + self._history_window = int(params.get("history_window", 20)) if self._threshold <= 0: raise ValueError(f"Threshold must be positive, got {self._threshold}") if self._quantity <= 0: @@ -41,6 +46,31 @@ class CombinedStrategy(BaseStrategy): raise ValueError(f"Weight must be positive, got {weight}") self._strategies.append((strategy, weight)) + def record_result(self, strategy_name: str, is_win: bool) -> None: + """Record a trade result for adaptive weighting.""" + if strategy_name not in self._trade_history: + self._trade_history[strategy_name] = [] + self._trade_history[strategy_name].append(is_win) + # Keep only last N results + if len(self._trade_history[strategy_name]) > self._history_window: + self._trade_history[strategy_name] = self._trade_history[strategy_name][ + -self._history_window : + ] + + def _get_adaptive_weight(self, strategy_name: str, base_weight: float) -> float: + """Get weight adjusted by recent performance.""" + if not self._adaptive_weights: + return base_weight + + history = self._trade_history.get(strategy_name, []) + if len(history) < 5: # Not enough data, use base weight + return base_weight + + win_rate = sum(1 for w in history if w) / len(history) + # Scale weight: 0.5x at 20% win rate, 1.0x at 50%, 1.5x at 80% + scale = 0.5 + win_rate # Range: 0.5 to 1.5 + return base_weight * scale + def reset(self) -> None: for strategy, _ in self._strategies: strategy.reset() @@ -49,7 +79,7 @@ class CombinedStrategy(BaseStrategy): if not self._strategies: return None - total_weight = sum(w for _, w in self._strategies) + total_weight = sum(self._get_adaptive_weight(s.name, w) for s, w in self._strategies) if total_weight == 0: return None @@ -59,12 +89,17 @@ class CombinedStrategy(BaseStrategy): for strategy, weight in self._strategies: signal = strategy.on_candle(candle) if signal is not None: + effective_weight = self._get_adaptive_weight(strategy.name, weight) if signal.side == OrderSide.BUY: - score += weight * signal.conviction - reasons.append(f"{strategy.name}:BUY({weight}*{signal.conviction:.2f})") + score += effective_weight * signal.conviction + reasons.append( + f"{strategy.name}:BUY({effective_weight}*{signal.conviction:.2f})" + ) elif signal.side == OrderSide.SELL: - score -= weight * signal.conviction - reasons.append(f"{strategy.name}:SELL({weight}*{signal.conviction:.2f})") + score -= effective_weight * signal.conviction + reasons.append( + f"{strategy.name}:SELL({effective_weight}*{signal.conviction:.2f})" + ) normalized = score / total_weight # Range: -1.0 to 1.0 diff --git a/services/strategy-engine/strategies/config/grid_strategy.yaml b/services/strategy-engine/strategies/config/grid_strategy.yaml index 607f3df..338bb4c 100644 --- a/services/strategy-engine/strategies/config/grid_strategy.yaml +++ b/services/strategy-engine/strategies/config/grid_strategy.yaml @@ -1,4 +1,4 @@ -lower_price: 60000 -upper_price: 70000 +lower_price: 170 +upper_price: 190 grid_count: 5 -quantity: "0.01" +quantity: "1" diff --git a/services/strategy-engine/strategies/config/moc_strategy.yaml b/services/strategy-engine/strategies/config/moc_strategy.yaml new file mode 100644 index 0000000..349ae1b --- /dev/null +++ b/services/strategy-engine/strategies/config/moc_strategy.yaml @@ -0,0 +1,13 @@ +# Market on Close (MOC) Strategy — US Stocks +quantity_pct: 0.2 # 20% of capital per position +stop_loss_pct: 2.0 # -2% stop loss +rsi_min: 30 # RSI lower bound +rsi_max: 60 # RSI upper bound (not overbought) +ema_period: 20 # EMA for trend confirmation +volume_avg_period: 20 # Volume average lookback +min_volume_ratio: 1.0 # Volume must be >= average +buy_start_utc: 19 # Buy window start (15:00 ET summer) +buy_end_utc: 21 # Buy window end (16:00 ET) +sell_start_utc: 13 # Sell window start (9:00 ET) +sell_end_utc: 15 # Sell window end (10:00 ET) +max_positions: 5 # Max simultaneous positions diff --git a/services/strategy-engine/strategies/ema_crossover_strategy.py b/services/strategy-engine/strategies/ema_crossover_strategy.py index a812eff..68d0ba3 100644 --- a/services/strategy-engine/strategies/ema_crossover_strategy.py +++ b/services/strategy-engine/strategies/ema_crossover_strategy.py @@ -17,6 +17,9 @@ class EmaCrossoverStrategy(BaseStrategy): self._long_period: int = 21 self._quantity: Decimal = Decimal("0.01") self._prev_short_above: bool | None = None + self._pending_signal: str | None = None # "BUY" or "SELL" if waiting for pullback + self._pullback_enabled: bool = True + self._pullback_tolerance: float = 0.002 # 0.2% tolerance around short EMA @property def warmup_period(self) -> int: @@ -27,6 +30,9 @@ class EmaCrossoverStrategy(BaseStrategy): self._long_period = int(params.get("long_period", 21)) self._quantity = Decimal(str(params.get("quantity", "0.01"))) + self._pullback_enabled = bool(params.get("pullback_enabled", True)) + self._pullback_tolerance = float(params.get("pullback_tolerance", 0.002)) + if self._short_period >= self._long_period: raise ValueError( f"EMA short_period must be < long_period, " @@ -48,8 +54,10 @@ class EmaCrossoverStrategy(BaseStrategy): ) def reset(self) -> None: + super().reset() self._closes.clear() self._prev_short_above = None + self._pending_signal = None def _ema_conviction(self, short_ema: float, long_ema: float, price: float) -> float: """Map EMA gap to conviction (0.1-1.0). Larger gap = stronger crossover.""" @@ -70,33 +78,87 @@ class EmaCrossoverStrategy(BaseStrategy): short_ema = series.ewm(span=self._short_period, adjust=False).mean().iloc[-1] long_ema = series.ewm(span=self._long_period, adjust=False).mean().iloc[-1] + close = float(candle.close) short_above = short_ema > long_ema signal = None if self._prev_short_above is not None: - conviction = self._ema_conviction(short_ema, long_ema, float(candle.close)) - if not self._prev_short_above and short_above: + prev = self._prev_short_above + conviction = self._ema_conviction(short_ema, long_ema, close) + + # Golden Cross detected + if not prev and short_above: + if self._pullback_enabled: + self._pending_signal = "BUY" + # Don't signal yet — wait for pullback + else: + signal = Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.BUY, + price=candle.close, + quantity=self._quantity, + conviction=conviction, + reason=f"Golden Cross: short EMA ({short_ema:.2f}) crossed above long EMA ({long_ema:.2f})", + ) + + # Death Cross detected + elif prev and not short_above: + if self._pullback_enabled: + self._pending_signal = "SELL" + else: + signal = Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.SELL, + price=candle.close, + quantity=self._quantity, + conviction=conviction, + reason=f"Death Cross: short EMA ({short_ema:.2f}) crossed below long EMA ({long_ema:.2f})", + ) + + self._prev_short_above = short_above + + if signal is not None: + return self._apply_filters(signal) + + # Check for pullback entry + if self._pending_signal == "BUY": + distance = abs(close - short_ema) / short_ema if short_ema > 0 else 999 + if distance <= self._pullback_tolerance: + self._pending_signal = None + conv = min(0.5 + (1.0 - distance / self._pullback_tolerance) * 0.5, 1.0) signal = Signal( strategy=self.name, symbol=candle.symbol, side=OrderSide.BUY, price=candle.close, quantity=self._quantity, - conviction=conviction, - reason=f"Golden Cross: short EMA ({short_ema:.2f}) crossed above long EMA ({long_ema:.2f})", + conviction=conv, + reason=f"EMA Golden Cross pullback entry (distance={distance:.4f})", ) - elif self._prev_short_above and not short_above: + return self._apply_filters(signal) + # Cancel if crossover reverses + if not short_above: + self._pending_signal = None + + if self._pending_signal == "SELL": + distance = abs(close - short_ema) / short_ema if short_ema > 0 else 999 + if distance <= self._pullback_tolerance: + self._pending_signal = None + conv = min(0.5 + (1.0 - distance / self._pullback_tolerance) * 0.5, 1.0) signal = Signal( strategy=self.name, symbol=candle.symbol, side=OrderSide.SELL, price=candle.close, quantity=self._quantity, - conviction=conviction, - reason=f"Death Cross: short EMA ({short_ema:.2f}) crossed below long EMA ({long_ema:.2f})", + conviction=conv, + reason=f"EMA Death Cross pullback entry (distance={distance:.4f})", ) + return self._apply_filters(signal) + # Cancel if crossover reverses + if short_above: + self._pending_signal = None - self._prev_short_above = short_above - if signal is not None: - return self._apply_filters(signal) return None diff --git a/services/strategy-engine/strategies/grid_strategy.py b/services/strategy-engine/strategies/grid_strategy.py index 70443ec..283bfe5 100644 --- a/services/strategy-engine/strategies/grid_strategy.py +++ b/services/strategy-engine/strategies/grid_strategy.py @@ -18,6 +18,9 @@ class GridStrategy(BaseStrategy): self._quantity: Decimal = Decimal("0.01") self._grid_levels: list[float] = [] self._last_zone: Optional[int] = None + self._exit_threshold_pct: float = 5.0 + self._out_of_range: bool = False + self._in_position: bool = False # Track if we have any grid positions @property def warmup_period(self) -> int: @@ -29,11 +32,15 @@ class GridStrategy(BaseStrategy): self._grid_count = int(params.get("grid_count", 5)) self._quantity = Decimal(str(params.get("quantity", "0.01"))) + self._exit_threshold_pct = float(params.get("exit_threshold_pct", 5.0)) + if self._lower_price >= self._upper_price: raise ValueError( f"Grid lower_price must be < upper_price, " f"got lower={self._lower_price}, upper={self._upper_price}" ) + if self._exit_threshold_pct <= 0: + raise ValueError(f"exit_threshold_pct must be > 0, got {self._exit_threshold_pct}") if self._grid_count < 2: raise ValueError(f"Grid grid_count must be >= 2, got {self._grid_count}") if self._quantity <= 0: @@ -53,7 +60,9 @@ class GridStrategy(BaseStrategy): ) def reset(self) -> None: + super().reset() self._last_zone = None + self._out_of_range = False def _get_zone(self, price: float) -> int: """Return the grid zone index for a given price. @@ -69,6 +78,31 @@ class GridStrategy(BaseStrategy): def on_candle(self, candle: Candle) -> Signal | None: self._update_filter_data(candle) price = float(candle.close) + + # Check if price is out of grid range + if self._grid_levels: + lower_bound = self._grid_levels[0] * (1 - self._exit_threshold_pct / 100) + upper_bound = self._grid_levels[-1] * (1 + self._exit_threshold_pct / 100) + + if price < lower_bound or price > upper_bound: + if not self._out_of_range: + self._out_of_range = True + # Exit signal — close positions + return self._apply_filters( + Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.SELL, + price=candle.close, + quantity=self._quantity, + conviction=0.8, + reason=f"Grid: price {price:.2f} broke out of range [{self._grid_levels[0]:.2f}, {self._grid_levels[-1]:.2f}]", + ) + ) + return None # Already out of range, no more signals + else: + self._out_of_range = False + current_zone = self._get_zone(price) if self._last_zone is None: diff --git a/services/strategy-engine/strategies/indicators/__init__.py b/services/strategy-engine/strategies/indicators/__init__.py index 1a54d59..3c713e6 100644 --- a/services/strategy-engine/strategies/indicators/__init__.py +++ b/services/strategy-engine/strategies/indicators/__init__.py @@ -1,12 +1,21 @@ """Reusable technical indicator functions.""" + from strategies.indicators.trend import ema, sma, macd, adx from strategies.indicators.volatility import atr, bollinger_bands, keltner_channels from strategies.indicators.momentum import rsi, stochastic from strategies.indicators.volume import volume_sma, volume_ratio, obv __all__ = [ - "ema", "sma", "macd", "adx", - "atr", "bollinger_bands", "keltner_channels", - "rsi", "stochastic", - "volume_sma", "volume_ratio", "obv", + "ema", + "sma", + "macd", + "adx", + "atr", + "bollinger_bands", + "keltner_channels", + "rsi", + "stochastic", + "volume_sma", + "volume_ratio", + "obv", ] diff --git a/services/strategy-engine/strategies/indicators/momentum.py b/services/strategy-engine/strategies/indicators/momentum.py index 395c52d..c479452 100644 --- a/services/strategy-engine/strategies/indicators/momentum.py +++ b/services/strategy-engine/strategies/indicators/momentum.py @@ -1,4 +1,5 @@ """Momentum indicators: RSI, Stochastic.""" + import pandas as pd import numpy as np diff --git a/services/strategy-engine/strategies/indicators/trend.py b/services/strategy-engine/strategies/indicators/trend.py index 10b69fa..c94a071 100644 --- a/services/strategy-engine/strategies/indicators/trend.py +++ b/services/strategy-engine/strategies/indicators/trend.py @@ -1,4 +1,5 @@ """Trend indicators: EMA, SMA, MACD, ADX.""" + import pandas as pd import numpy as np @@ -101,4 +102,4 @@ def adx( for i in range(2 * period + 1, n): adx_vals[i] = (adx_vals[i - 1] * (period - 1) + dx[i]) / period - return pd.Series(adx_vals, index=closes.index if hasattr(closes, 'index') else None) + return pd.Series(adx_vals, index=closes.index if hasattr(closes, "index") else None) diff --git a/services/strategy-engine/strategies/indicators/volatility.py b/services/strategy-engine/strategies/indicators/volatility.py index d47eb86..c16143e 100644 --- a/services/strategy-engine/strategies/indicators/volatility.py +++ b/services/strategy-engine/strategies/indicators/volatility.py @@ -1,4 +1,5 @@ """Volatility indicators: ATR, Bollinger Bands, Keltner Channels.""" + import pandas as pd import numpy as np @@ -30,7 +31,7 @@ def atr( for i in range(period, n): atr_vals[i] = (atr_vals[i - 1] * (period - 1) + tr[i]) / period - return pd.Series(atr_vals, index=closes.index if hasattr(closes, 'index') else None) + return pd.Series(atr_vals, index=closes.index if hasattr(closes, "index") else None) def bollinger_bands( @@ -62,6 +63,7 @@ def keltner_channels( Returns: (upper_channel, middle_ema, lower_channel) """ from strategies.indicators.trend import ema as calc_ema + middle = calc_ema(closes, ema_period) atr_vals = atr(highs, lows, closes, atr_period) upper = middle + atr_multiplier * atr_vals diff --git a/services/strategy-engine/strategies/indicators/volume.py b/services/strategy-engine/strategies/indicators/volume.py index 323d427..502f1ce 100644 --- a/services/strategy-engine/strategies/indicators/volume.py +++ b/services/strategy-engine/strategies/indicators/volume.py @@ -1,4 +1,5 @@ """Volume indicators: Volume SMA, Volume Ratio, OBV.""" + import pandas as pd import numpy as np diff --git a/services/strategy-engine/strategies/macd_strategy.py b/services/strategy-engine/strategies/macd_strategy.py index 67c5e44..356a42b 100644 --- a/services/strategy-engine/strategies/macd_strategy.py +++ b/services/strategy-engine/strategies/macd_strategy.py @@ -18,6 +18,8 @@ class MacdStrategy(BaseStrategy): self._quantity: Decimal = Decimal("0.01") self._closes: deque[float] = deque(maxlen=500) self._prev_histogram: float | None = None + self._prev_macd: float | None = None + self._prev_signal: float | None = None @property def warmup_period(self) -> int: @@ -54,6 +56,8 @@ class MacdStrategy(BaseStrategy): def reset(self) -> None: self._closes.clear() self._prev_histogram = None + self._prev_macd = None + self._prev_signal = None def _macd_conviction(self, histogram_value: float, price: float) -> float: """Map histogram magnitude to conviction (0.1-1.0). @@ -81,13 +85,45 @@ class MacdStrategy(BaseStrategy): histogram = macd_line - signal_line current_histogram = float(histogram.iloc[-1]) - signal = None + macd_val = float(macd_line.iloc[-1]) + signal_val = float(signal_line.iloc[-1]) + result_signal = None + + # Signal-line crossover detection (MACD crosses signal line directly) + if self._prev_macd is not None and self._prev_signal is not None: + # Bullish: MACD crosses above signal + if self._prev_macd <= self._prev_signal and macd_val > signal_val: + distance_from_zero = abs(macd_val) / float(candle.close) * 1000 + conv = min(max(distance_from_zero, 0.3), 1.0) + result_signal = Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.BUY, + price=candle.close, + quantity=self._quantity, + conviction=conv, + reason="MACD signal-line bullish crossover", + ) + # Bearish: MACD crosses below signal + elif self._prev_macd >= self._prev_signal and macd_val < signal_val: + distance_from_zero = abs(macd_val) / float(candle.close) * 1000 + conv = min(max(distance_from_zero, 0.3), 1.0) + result_signal = Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.SELL, + price=candle.close, + quantity=self._quantity, + conviction=conv, + reason="MACD signal-line bearish crossover", + ) - if self._prev_histogram is not None: + # Histogram crossover detection (existing logic, as secondary signal) + if result_signal is None and self._prev_histogram is not None: conviction = self._macd_conviction(current_histogram, float(candle.close)) # Bullish crossover: histogram crosses from negative to positive if self._prev_histogram <= 0 and current_histogram > 0: - signal = Signal( + result_signal = Signal( strategy=self.name, symbol=candle.symbol, side=OrderSide.BUY, @@ -98,7 +134,7 @@ class MacdStrategy(BaseStrategy): ) # Bearish crossover: histogram crosses from positive to negative elif self._prev_histogram >= 0 and current_histogram < 0: - signal = Signal( + result_signal = Signal( strategy=self.name, symbol=candle.symbol, side=OrderSide.SELL, @@ -109,6 +145,8 @@ class MacdStrategy(BaseStrategy): ) self._prev_histogram = current_histogram - if signal is not None: - return self._apply_filters(signal) + self._prev_macd = macd_val + self._prev_signal = signal_val + if result_signal is not None: + return self._apply_filters(result_signal) return None diff --git a/services/strategy-engine/strategies/moc_strategy.py b/services/strategy-engine/strategies/moc_strategy.py new file mode 100644 index 0000000..7eaa59e --- /dev/null +++ b/services/strategy-engine/strategies/moc_strategy.py @@ -0,0 +1,230 @@ +"""Market on Close (MOC) Strategy — US Stock 종가매매. + +Rules: +- Buy: 15:50-16:00 ET (market close) when screening criteria met +- Sell: 9:35-10:00 ET (market open next day) +- Screening: bullish candle, volume above average, RSI 30-60, positive momentum +- Risk: -2% stop loss, max 5 positions, 20% of capital per position +""" + +from collections import deque +from decimal import Decimal +from datetime import datetime + +import pandas as pd + +from shared.models import Candle, Signal, OrderSide +from strategies.base import BaseStrategy + + +class MocStrategy(BaseStrategy): + """Market on Close strategy for overnight gap trading.""" + + name: str = "moc" + + def __init__(self) -> None: + super().__init__() + # Parameters + self._quantity_pct: float = 0.2 # 20% of capital per trade + self._stop_loss_pct: float = 2.0 + self._rsi_min: float = 30.0 + self._rsi_max: float = 60.0 + self._ema_period: int = 20 + self._volume_avg_period: int = 20 + self._min_volume_ratio: float = 1.0 # Volume must be above average + # Session times (UTC hours) + self._buy_start_utc: int = 19 # 15:00 ET = 19:00 UTC (summer) / 20:00 UTC (winter) + self._buy_end_utc: int = 21 # 16:00 ET = 20:00 UTC / 21:00 UTC + self._sell_start_utc: int = 13 # 9:00 ET = 13:00 UTC / 14:00 UTC + self._sell_end_utc: int = 15 # 10:00 ET = 14:00 UTC / 15:00 UTC + self._max_positions: int = 5 + # State + self._closes: deque[float] = deque(maxlen=200) + self._volumes: deque[float] = deque(maxlen=200) + self._highs: deque[float] = deque(maxlen=200) + self._lows: deque[float] = deque(maxlen=200) + self._in_position: bool = False + self._entry_price: float = 0.0 + self._today: str | None = None + self._bought_today: bool = False + self._sold_today: bool = False + + @property + def warmup_period(self) -> int: + return max(self._ema_period, self._volume_avg_period) + 1 + + def configure(self, params: dict) -> None: + self._quantity_pct = float(params.get("quantity_pct", 0.2)) + self._stop_loss_pct = float(params.get("stop_loss_pct", 2.0)) + self._rsi_min = float(params.get("rsi_min", 30.0)) + self._rsi_max = float(params.get("rsi_max", 60.0)) + self._ema_period = int(params.get("ema_period", 20)) + self._volume_avg_period = int(params.get("volume_avg_period", 20)) + self._min_volume_ratio = float(params.get("min_volume_ratio", 1.0)) + self._buy_start_utc = int(params.get("buy_start_utc", 19)) + self._buy_end_utc = int(params.get("buy_end_utc", 21)) + self._sell_start_utc = int(params.get("sell_start_utc", 13)) + self._sell_end_utc = int(params.get("sell_end_utc", 15)) + self._max_positions = int(params.get("max_positions", 5)) + + if self._quantity_pct <= 0 or self._quantity_pct > 1: + raise ValueError(f"quantity_pct must be 0-1, got {self._quantity_pct}") + if self._stop_loss_pct <= 0: + raise ValueError(f"stop_loss_pct must be positive, got {self._stop_loss_pct}") + + def reset(self) -> None: + super().reset() + self._closes.clear() + self._volumes.clear() + self._highs.clear() + self._lows.clear() + self._in_position = False + self._entry_price = 0.0 + self._today = None + self._bought_today = False + self._sold_today = False + + def _is_buy_window(self, dt: datetime) -> bool: + """Check if in buy window (near market close).""" + hour = dt.hour + return self._buy_start_utc <= hour < self._buy_end_utc + + def _is_sell_window(self, dt: datetime) -> bool: + """Check if in sell window (near market open).""" + hour = dt.hour + return self._sell_start_utc <= hour < self._sell_end_utc + + def _compute_rsi(self, period: int = 14) -> float | None: + if len(self._closes) < period + 1: + return None + series = pd.Series(list(self._closes)) + delta = series.diff() + gain = delta.clip(lower=0) + loss = -delta.clip(upper=0) + avg_gain = gain.ewm(com=period - 1, min_periods=period).mean() + avg_loss = loss.ewm(com=period - 1, min_periods=period).mean() + rs = avg_gain / avg_loss.replace(0, float("nan")) + rsi = 100 - (100 / (1 + rs)) + val = rsi.iloc[-1] + return None if pd.isna(val) else float(val) + + def _is_bullish_candle(self, candle: Candle) -> bool: + return float(candle.close) > float(candle.open) + + def _price_above_ema(self) -> bool: + if len(self._closes) < self._ema_period: + return True + series = pd.Series(list(self._closes)) + ema = series.ewm(span=self._ema_period, adjust=False).mean().iloc[-1] + return self._closes[-1] >= ema + + def _volume_above_average(self) -> bool: + if len(self._volumes) < self._volume_avg_period: + return True + avg = sum(list(self._volumes)[-self._volume_avg_period :]) / self._volume_avg_period + return avg > 0 and self._volumes[-1] / avg >= self._min_volume_ratio + + def _positive_momentum(self) -> bool: + """Check if price has positive short-term momentum (close > close 5 bars ago).""" + if len(self._closes) < 6: + return True + return self._closes[-1] > self._closes[-6] + + def on_candle(self, candle: Candle) -> Signal | None: + self._update_filter_data(candle) + + close = float(candle.close) + self._closes.append(close) + self._volumes.append(float(candle.volume)) + self._highs.append(float(candle.high)) + self._lows.append(float(candle.low)) + + # Daily reset + day = candle.open_time.strftime("%Y-%m-%d") + if self._today != day: + self._today = day + self._bought_today = False + self._sold_today = False + + # --- SELL LOGIC (market open next day) --- + if self._in_position and self._is_sell_window(candle.open_time): + if not self._sold_today: + pnl_pct = (close - self._entry_price) / self._entry_price * 100 + self._in_position = False + self._sold_today = True + + conv = 0.8 if pnl_pct > 0 else 0.5 + return self._apply_filters( + Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.SELL, + price=candle.close, + quantity=Decimal(str(self._quantity_pct)), + conviction=conv, + reason=f"MOC sell at open, PnL {pnl_pct:.2f}%", + ) + ) + + # --- STOP LOSS --- + if self._in_position: + pnl_pct = (close - self._entry_price) / self._entry_price * 100 + if pnl_pct <= -self._stop_loss_pct: + self._in_position = False + return self._apply_filters( + Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.SELL, + price=candle.close, + quantity=Decimal(str(self._quantity_pct)), + conviction=1.0, + stop_loss=candle.close, + reason=f"MOC stop loss {pnl_pct:.2f}% <= -{self._stop_loss_pct}%", + ) + ) + + # --- BUY LOGIC (near market close) --- + if not self._in_position and self._is_buy_window(candle.open_time): + if self._bought_today: + return None + + # Screening criteria + rsi = self._compute_rsi() + if rsi is None: + return None + + checks = [ + self._rsi_min <= rsi <= self._rsi_max, # RSI in sweet spot + self._is_bullish_candle(candle), # Bullish candle + self._price_above_ema(), # Above EMA (uptrend) + self._volume_above_average(), # Volume confirmation + self._positive_momentum(), # Short-term momentum + ] + + if all(checks): + self._in_position = True + self._entry_price = close + self._bought_today = True + + # Conviction based on RSI position within range + rsi_range = self._rsi_max - self._rsi_min + rsi_pos = (rsi - self._rsi_min) / rsi_range if rsi_range > 0 else 0.5 + conv = 0.5 + (1.0 - rsi_pos) * 0.4 # Lower RSI = higher conviction + + sl = candle.close * (1 - Decimal(str(self._stop_loss_pct / 100))) + + return self._apply_filters( + Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.BUY, + price=candle.close, + quantity=Decimal(str(self._quantity_pct)), + conviction=conv, + stop_loss=sl, + reason=f"MOC buy: RSI={rsi:.1f}, bullish candle, above EMA, vol OK", + ) + ) + + return None diff --git a/services/strategy-engine/strategies/rsi_strategy.py b/services/strategy-engine/strategies/rsi_strategy.py index 0ec6780..0646d8c 100644 --- a/services/strategy-engine/strategies/rsi_strategy.py +++ b/services/strategy-engine/strategies/rsi_strategy.py @@ -34,6 +34,14 @@ class RsiStrategy(BaseStrategy): self._oversold: float = 30.0 self._overbought: float = 70.0 self._quantity: Decimal = Decimal("0.01") + # Divergence detection state + self._price_lows: deque[float] = deque(maxlen=5) + self._price_highs: deque[float] = deque(maxlen=5) + self._rsi_at_lows: deque[float] = deque(maxlen=5) + self._rsi_at_highs: deque[float] = deque(maxlen=5) + self._prev_close: float | None = None + self._prev_prev_close: float | None = None + self._prev_rsi: float | None = None @property def warmup_period(self) -> int: @@ -65,6 +73,13 @@ class RsiStrategy(BaseStrategy): def reset(self) -> None: self._closes.clear() + self._price_lows.clear() + self._price_highs.clear() + self._rsi_at_lows.clear() + self._rsi_at_highs.clear() + self._prev_close = None + self._prev_prev_close = None + self._prev_rsi = None def _rsi_conviction(self, rsi_value: float) -> float: """Map RSI value to conviction strength (0.0-1.0). @@ -86,14 +101,76 @@ class RsiStrategy(BaseStrategy): self._closes.append(float(candle.close)) if len(self._closes) < self._period + 1: + self._prev_prev_close = self._prev_close + self._prev_close = float(candle.close) return None series = pd.Series(list(self._closes)) rsi_value = _compute_rsi(series, self._period) if rsi_value is None: + self._prev_prev_close = self._prev_close + self._prev_close = float(candle.close) return None + close = float(candle.close) + + # Detect swing points for divergence + if self._prev_close is not None and self._prev_prev_close is not None: + # Swing low: prev_close < both neighbors + if self._prev_close < self._prev_prev_close and self._prev_close < close: + self._price_lows.append(self._prev_close) + self._rsi_at_lows.append( + self._prev_rsi if self._prev_rsi is not None else rsi_value + ) + # Swing high: prev_close > both neighbors + if self._prev_close > self._prev_prev_close and self._prev_close > close: + self._price_highs.append(self._prev_close) + self._rsi_at_highs.append( + self._prev_rsi if self._prev_rsi is not None else rsi_value + ) + + # Check bullish divergence: price lower low, RSI higher low + if len(self._price_lows) >= 2: + if ( + self._price_lows[-1] < self._price_lows[-2] + and self._rsi_at_lows[-1] > self._rsi_at_lows[-2] + ): + signal = Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.BUY, + price=candle.close, + quantity=self._quantity, + conviction=0.9, + reason="RSI bullish divergence", + ) + self._prev_rsi = rsi_value + self._prev_prev_close = self._prev_close + self._prev_close = close + return self._apply_filters(signal) + + # Check bearish divergence: price higher high, RSI lower high + if len(self._price_highs) >= 2: + if ( + self._price_highs[-1] > self._price_highs[-2] + and self._rsi_at_highs[-1] < self._rsi_at_highs[-2] + ): + signal = Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.SELL, + price=candle.close, + quantity=self._quantity, + conviction=0.9, + reason="RSI bearish divergence", + ) + self._prev_rsi = rsi_value + self._prev_prev_close = self._prev_close + self._prev_close = close + return self._apply_filters(signal) + + # Existing oversold/overbought logic (secondary signals) if rsi_value < self._oversold: signal = Signal( strategy=self.name, @@ -104,6 +181,9 @@ class RsiStrategy(BaseStrategy): conviction=self._rsi_conviction(rsi_value), reason=f"RSI {rsi_value:.2f} below oversold threshold {self._oversold}", ) + self._prev_rsi = rsi_value + self._prev_prev_close = self._prev_close + self._prev_close = close return self._apply_filters(signal) elif rsi_value > self._overbought: signal = Signal( @@ -115,6 +195,12 @@ class RsiStrategy(BaseStrategy): conviction=self._rsi_conviction(rsi_value), reason=f"RSI {rsi_value:.2f} above overbought threshold {self._overbought}", ) + self._prev_rsi = rsi_value + self._prev_prev_close = self._prev_close + self._prev_close = close return self._apply_filters(signal) + self._prev_rsi = rsi_value + self._prev_prev_close = self._prev_close + self._prev_close = close return None diff --git a/services/strategy-engine/strategies/volume_profile_strategy.py b/services/strategy-engine/strategies/volume_profile_strategy.py index 324f1c2..ef2ae14 100644 --- a/services/strategy-engine/strategies/volume_profile_strategy.py +++ b/services/strategy-engine/strategies/volume_profile_strategy.py @@ -56,7 +56,8 @@ class VolumeProfileStrategy(BaseStrategy): self._was_below_va = False self._was_above_va = False - def _compute_value_area(self) -> tuple[float, float, float] | None: + def _compute_value_area(self) -> tuple[float, float, float, list[float], list[float]] | None: + """Compute POC, VA low, VA high, HVN levels, LVN levels.""" data = list(self._candles) if len(data) < self._lookback_period: return None @@ -67,7 +68,7 @@ class VolumeProfileStrategy(BaseStrategy): min_price = prices.min() max_price = prices.max() if min_price == max_price: - return (float(min_price), float(min_price), float(max_price)) + return (float(min_price), float(min_price), float(max_price), [], []) bin_edges = np.linspace(min_price, max_price, self._num_bins + 1) vol_profile = np.zeros(self._num_bins) @@ -84,7 +85,7 @@ class VolumeProfileStrategy(BaseStrategy): # Value Area: expand from POC outward total_volume = vol_profile.sum() if total_volume == 0: - return (poc, float(bin_edges[0]), float(bin_edges[-1])) + return (poc, float(bin_edges[0]), float(bin_edges[-1]), [], []) target_volume = self._value_area_pct * total_volume accumulated = vol_profile[poc_idx] @@ -111,7 +112,20 @@ class VolumeProfileStrategy(BaseStrategy): va_low = float(bin_edges[low_idx]) va_high = float(bin_edges[high_idx + 1]) - return (poc, va_low, va_high) + # HVN/LVN detection + mean_vol = vol_profile.mean() + std_vol = vol_profile.std() + + hvn_levels: list[float] = [] + lvn_levels: list[float] = [] + for i in range(len(vol_profile)): + mid = float((bin_edges[i] + bin_edges[i + 1]) / 2) + if vol_profile[i] > mean_vol + std_vol: + hvn_levels.append(mid) + elif vol_profile[i] < mean_vol - 0.5 * std_vol and vol_profile[i] > 0: + lvn_levels.append(mid) + + return (poc, va_low, va_high, hvn_levels, lvn_levels) def on_candle(self, candle: Candle) -> Signal | None: self._update_filter_data(candle) @@ -123,13 +137,41 @@ class VolumeProfileStrategy(BaseStrategy): if result is None: return None - poc, va_low, va_high = result + poc, va_low, va_high, hvn_levels, lvn_levels = result if close < va_low: self._was_below_va = True if close > va_high: self._was_above_va = True + # HVN bounce signals (stronger than regular VA bounces) + for hvn in hvn_levels: + if abs(close - hvn) / hvn < 0.005: # Within 0.5% of HVN + if self._was_below_va and close >= va_low: + self._was_below_va = False + signal = Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.BUY, + price=candle.close, + quantity=self._quantity, + conviction=0.85, + reason=f"Price near HVN {hvn:.2f}, bounced from below VA low {va_low:.2f} to {close:.2f}", + ) + return self._apply_filters(signal) + if self._was_above_va and close <= va_high: + self._was_above_va = False + signal = Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.SELL, + price=candle.close, + quantity=self._quantity, + conviction=0.85, + reason=f"Price near HVN {hvn:.2f}, rejected from above VA high {va_high:.2f} to {close:.2f}", + ) + return self._apply_filters(signal) + # BUY: was below VA, price bounces back between va_low and poc if self._was_below_va and va_low <= close <= poc: self._was_below_va = False diff --git a/services/strategy-engine/strategies/vwap_strategy.py b/services/strategy-engine/strategies/vwap_strategy.py index c525ff3..d64950e 100644 --- a/services/strategy-engine/strategies/vwap_strategy.py +++ b/services/strategy-engine/strategies/vwap_strategy.py @@ -1,3 +1,4 @@ +from collections import deque from decimal import Decimal from shared.models import Candle, Signal, OrderSide @@ -16,6 +17,9 @@ class VwapStrategy(BaseStrategy): self._candle_count: int = 0 self._was_below_vwap: bool = False self._was_above_vwap: bool = False + self._current_date: str | None = None # Track date for daily reset + self._tp_values: deque[float] = deque(maxlen=500) # For std calculation + self._vwap_values: deque[float] = deque(maxlen=500) @property def warmup_period(self) -> int: @@ -41,11 +45,15 @@ class VwapStrategy(BaseStrategy): ) def reset(self) -> None: + super().reset() self._cumulative_tp_vol = 0.0 self._cumulative_vol = 0.0 self._candle_count = 0 self._was_below_vwap = False self._was_above_vwap = False + self._current_date = None + self._tp_values.clear() + self._vwap_values.clear() def _vwap_conviction(self, deviation: float) -> float: """Map VWAP deviation magnitude to conviction (0.1-1.0). @@ -58,6 +66,20 @@ class VwapStrategy(BaseStrategy): def on_candle(self, candle: Candle) -> Signal | None: self._update_filter_data(candle) + + # Daily reset + candle_date = candle.open_time.strftime("%Y-%m-%d") + if self._current_date is not None and candle_date != self._current_date: + # New day — reset VWAP + self._cumulative_tp_vol = 0.0 + self._cumulative_vol = 0.0 + self._candle_count = 0 + self._was_below_vwap = False + self._was_above_vwap = False + self._tp_values.clear() + self._vwap_values.clear() + self._current_date = candle_date + high = float(candle.high) low = float(candle.low) close = float(candle.close) @@ -77,6 +99,19 @@ class VwapStrategy(BaseStrategy): vwap = self._cumulative_tp_vol / self._cumulative_vol if vwap == 0.0: return None + + # Track values for deviation band calculation + self._tp_values.append(typical_price) + self._vwap_values.append(vwap) + + # Standard deviation of (TP - VWAP) for bands + std_dev = 0.0 + if len(self._tp_values) >= 2: + diffs = [tp - v for tp, v in zip(self._tp_values, self._vwap_values)] + mean_diff = sum(diffs) / len(diffs) + variance = sum((d - mean_diff) ** 2 for d in diffs) / len(diffs) + std_dev = variance**0.5 + deviation = (close - vwap) / vwap if deviation < -self._deviation_threshold: @@ -84,10 +119,20 @@ class VwapStrategy(BaseStrategy): if deviation > self._deviation_threshold: self._was_above_vwap = True + # Determine conviction based on deviation bands + def _band_conviction(price: float) -> float: + if std_dev > 0 and len(self._tp_values) >= 2: + dist_from_vwap = abs(price - vwap) + if dist_from_vwap >= 2 * std_dev: + return 0.9 + elif dist_from_vwap >= std_dev: + return 0.6 + return 0.5 + # Mean reversion from below: was below VWAP, now back near it if self._was_below_vwap and abs(deviation) <= self._deviation_threshold: self._was_below_vwap = False - conviction = self._vwap_conviction(deviation) + conviction = _band_conviction(close) signal = Signal( strategy=self.name, symbol=candle.symbol, @@ -102,7 +147,7 @@ class VwapStrategy(BaseStrategy): # Mean reversion from above: was above VWAP, now back near it if self._was_above_vwap and abs(deviation) <= self._deviation_threshold: self._was_above_vwap = False - conviction = self._vwap_conviction(deviation) + conviction = _band_conviction(close) signal = Signal( strategy=self.name, symbol=candle.symbol, diff --git a/services/strategy-engine/tests/test_base_filters.py b/services/strategy-engine/tests/test_base_filters.py index 97d9e16..ae9ca05 100644 --- a/services/strategy-engine/tests/test_base_filters.py +++ b/services/strategy-engine/tests/test_base_filters.py @@ -1,11 +1,12 @@ """Tests for BaseStrategy filters (ADX, volume, ATR stops).""" + import sys from pathlib import Path + sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from decimal import Decimal from datetime import datetime, timezone -import pytest from shared.models import Candle, Signal, OrderSide from strategies.base import BaseStrategy @@ -28,9 +29,12 @@ class DummyStrategy(BaseStrategy): def on_candle(self, candle: Candle) -> Signal | None: self._update_filter_data(candle) signal = Signal( - strategy=self.name, symbol=candle.symbol, - side=OrderSide.BUY, price=candle.close, - quantity=self._quantity, reason="test", + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.BUY, + price=candle.close, + quantity=self._quantity, + reason="test", ) return self._apply_filters(signal) @@ -39,10 +43,13 @@ def _candle(price=100.0, volume=10.0, high=None, low=None): h = high if high is not None else price + 5 lo = low if low is not None else price - 5 return Candle( - symbol="BTCUSDT", timeframe="1h", + symbol="AAPL", + timeframe="1h", open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), - open=Decimal(str(price)), high=Decimal(str(h)), - low=Decimal(str(lo)), close=Decimal(str(price)), + open=Decimal(str(price)), + high=Decimal(str(h)), + low=Decimal(str(lo)), + close=Decimal(str(price)), volume=Decimal(str(volume)), ) diff --git a/services/strategy-engine/tests/test_bollinger_strategy.py b/services/strategy-engine/tests/test_bollinger_strategy.py index 348a9e0..8261377 100644 --- a/services/strategy-engine/tests/test_bollinger_strategy.py +++ b/services/strategy-engine/tests/test_bollinger_strategy.py @@ -10,7 +10,7 @@ from strategies.bollinger_strategy import BollingerStrategy def make_candle(close: float) -> Candle: return Candle( - symbol="BTC/USDT", + symbol="AAPL", timeframe="1m", open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), open=Decimal(str(close)), @@ -23,7 +23,7 @@ def make_candle(close: float) -> Candle: def _make_strategy() -> BollingerStrategy: s = BollingerStrategy() - s.configure({"period": 5, "num_std": 1.0, "min_bandwidth": 0.0}) + s.configure({"period": 5, "num_std": 1.0, "min_bandwidth": 0.0, "squeeze_threshold": 0.0}) return s @@ -99,3 +99,79 @@ def test_bollinger_reset_clears_state(): assert len(strategy._closes) == 1 assert strategy._was_below_lower is False assert strategy._was_above_upper is False + assert strategy._in_squeeze is False + assert strategy._squeeze_bars == 0 + + +def test_bollinger_squeeze_detection(): + """Tight bandwidth → no signal during squeeze.""" + # Use a strategy with a high squeeze threshold so constant prices trigger squeeze + s = BollingerStrategy() + s.configure( + { + "period": 5, + "num_std": 2.0, + "min_bandwidth": 0.0, + "squeeze_threshold": 0.5, # Very high threshold to ensure squeeze triggers + } + ) + + # Feed identical prices → bandwidth = 0 (below any threshold) + for _ in range(6): + result = s.on_candle(make_candle(100.0)) + + # With identical prices, std=0, bandwidth=0 < 0.5 → squeeze, no signal + assert s._in_squeeze is True + assert result is None + + +def test_bollinger_squeeze_breakout_buy(): + """Squeeze ends with price above SMA → BUY signal.""" + s = BollingerStrategy() + s.configure( + { + "period": 5, + "num_std": 1.0, + "min_bandwidth": 0.0, + "squeeze_threshold": 0.01, + } + ) + + # Feed identical prices to create a squeeze (bandwidth = 0) + for _ in range(6): + s.on_candle(make_candle(100.0)) + + assert s._in_squeeze is True + + # Now feed a price that creates enough spread to exit squeeze AND is above SMA + signal = s.on_candle(make_candle(120.0)) + assert signal is not None + assert signal.side == OrderSide.BUY + assert "squeeze breakout UP" in signal.reason + + +def test_bollinger_pct_b_conviction(): + """Signals near band extremes have higher conviction via %B.""" + s = BollingerStrategy() + s.configure( + { + "period": 5, + "num_std": 1.0, + "min_bandwidth": 0.0, + "squeeze_threshold": 0.0, # Disable squeeze for this test + } + ) + + # Build up with stable prices + for _ in range(5): + s.on_candle(make_candle(100.0)) + + # Drop below lower band + s.on_candle(make_candle(50.0)) + + # Recover just at the lower band edge — %B close to 0 → high conviction + signal = s.on_candle(make_candle(100.0)) + assert signal is not None + assert signal.side == OrderSide.BUY + # conviction = max(1.0 - pct_b, 0.3), with pct_b near lower → conviction should be >= 0.3 + assert signal.conviction >= 0.3 diff --git a/services/strategy-engine/tests/test_combined_strategy.py b/services/strategy-engine/tests/test_combined_strategy.py index 3408a89..8a4dc74 100644 --- a/services/strategy-engine/tests/test_combined_strategy.py +++ b/services/strategy-engine/tests/test_combined_strategy.py @@ -72,7 +72,7 @@ class NeutralStrategy(BaseStrategy): def _candle(price=100.0): return Candle( - symbol="BTCUSDT", + symbol="AAPL", timeframe="1m", open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), open=Decimal(str(price)), @@ -167,3 +167,60 @@ def test_combined_invalid_weight(): c.configure({}) with pytest.raises(ValueError): c.add_strategy(AlwaysBuyStrategy(), weight=-1.0) + + +def test_combined_record_result(): + """Verify trade history tracking works correctly.""" + c = CombinedStrategy() + c.configure({"adaptive_weights": True, "history_window": 5}) + + c.record_result("test_strat", True) + c.record_result("test_strat", False) + c.record_result("test_strat", True) + + assert len(c._trade_history["test_strat"]) == 3 + assert c._trade_history["test_strat"] == [True, False, True] + + # Fill beyond window size to test trimming + for _ in range(5): + c.record_result("test_strat", False) + + assert len(c._trade_history["test_strat"]) == 5 # Trimmed to history_window + + +def test_combined_adaptive_weight_increases_for_winners(): + """Strategy with high win rate gets higher effective weight.""" + c = CombinedStrategy() + c.configure({"threshold": 0.3, "adaptive_weights": True, "history_window": 20}) + c.add_strategy(AlwaysBuyStrategy(), weight=1.0) + + # Record high win rate for always_buy (80% wins) + for _ in range(8): + c.record_result("always_buy", True) + for _ in range(2): + c.record_result("always_buy", False) + + # Adaptive weight should be > base weight (1.0) + adaptive_w = c._get_adaptive_weight("always_buy", 1.0) + assert adaptive_w > 1.0 + # 80% win rate -> scale = 0.5 + 0.8 = 1.3 -> weight = 1.3 + assert abs(adaptive_w - 1.3) < 0.01 + + +def test_combined_adaptive_weight_decreases_for_losers(): + """Strategy with low win rate gets lower effective weight.""" + c = CombinedStrategy() + c.configure({"threshold": 0.3, "adaptive_weights": True, "history_window": 20}) + c.add_strategy(AlwaysBuyStrategy(), weight=1.0) + + # Record low win rate for always_buy (20% wins) + for _ in range(2): + c.record_result("always_buy", True) + for _ in range(8): + c.record_result("always_buy", False) + + # Adaptive weight should be < base weight (1.0) + adaptive_w = c._get_adaptive_weight("always_buy", 1.0) + assert adaptive_w < 1.0 + # 20% win rate -> scale = 0.5 + 0.2 = 0.7 -> weight = 0.7 + assert abs(adaptive_w - 0.7) < 0.01 diff --git a/services/strategy-engine/tests/test_ema_crossover_strategy.py b/services/strategy-engine/tests/test_ema_crossover_strategy.py index 0cf767b..7028eb0 100644 --- a/services/strategy-engine/tests/test_ema_crossover_strategy.py +++ b/services/strategy-engine/tests/test_ema_crossover_strategy.py @@ -10,7 +10,7 @@ from strategies.ema_crossover_strategy import EmaCrossoverStrategy def make_candle(close: float) -> Candle: return Candle( - symbol="BTC/USDT", + symbol="AAPL", timeframe="1m", open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), open=Decimal(str(close)), @@ -21,9 +21,18 @@ def make_candle(close: float) -> Candle: ) -def _make_strategy(short: int = 3, long: int = 6) -> EmaCrossoverStrategy: +def _make_strategy( + short: int = 3, long: int = 6, pullback_enabled: bool = False +) -> EmaCrossoverStrategy: s = EmaCrossoverStrategy() - s.configure({"short_period": short, "long_period": long, "quantity": "0.01"}) + s.configure( + { + "short_period": short, + "long_period": long, + "quantity": "0.01", + "pullback_enabled": pullback_enabled, + } + ) return s @@ -97,3 +106,110 @@ def test_ema_reset_clears_state(): # Internal state should be cleared assert len(strategy._closes) == 1 assert strategy._prev_short_above is None + assert strategy._pending_signal is None + + +def test_ema_pullback_entry(): + """Crossover detected, then pullback to short EMA triggers signal.""" + strategy = EmaCrossoverStrategy() + strategy.configure( + { + "short_period": 3, + "long_period": 6, + "quantity": "0.01", + "pullback_enabled": True, + "pullback_tolerance": 0.05, # 5% tolerance for test simplicity + } + ) + + # Declining prices so short EMA stays below long EMA + declining = [100, 98, 96, 94, 92, 90, 88, 86, 84, 82] + for price in declining: + strategy.on_candle(make_candle(price)) + + # Sharp rise to force golden cross — with pullback enabled, no signal yet + rising = [120, 140, 160] + for price in rising: + strategy.on_candle(make_candle(price)) + + # With pullback enabled, crossover should NOT produce immediate signal + # but _pending_signal should be set + assert strategy._pending_signal == "BUY" + + # Now feed a candle whose close is near the short EMA (pullback) + # The short EMA will be tracking recent prices; feed a price that pulls back + # toward it. We use a moderate price to get close to short EMA. + import pandas as pd + + series = pd.Series(list(strategy._closes)) + short_ema_val = series.ewm(span=3, adjust=False).mean().iloc[-1] + # Feed a candle at approximately the short EMA value + result = strategy.on_candle(make_candle(short_ema_val)) + assert result is not None + assert result.side == OrderSide.BUY + assert "pullback" in result.reason + + +def test_ema_pullback_cancelled_on_reversal(): + """Crossover detected, then reversal cancels the pending signal.""" + strategy = EmaCrossoverStrategy() + strategy.configure( + { + "short_period": 3, + "long_period": 6, + "quantity": "0.01", + "pullback_enabled": True, + "pullback_tolerance": 0.001, # Very tight tolerance — won't trigger easily + } + ) + + # Declining prices + declining = [100, 98, 96, 94, 92, 90, 88, 86, 84, 82] + for price in declining: + strategy.on_candle(make_candle(price)) + + # Sharp rise to force golden cross + for price in [120, 140, 160]: + strategy.on_candle(make_candle(price)) + + assert strategy._pending_signal == "BUY" + + # Now sharp decline to reverse the crossover (death cross) + for price in [60, 40, 20]: + strategy.on_candle(make_candle(price)) + + # The BUY pending signal should be cancelled because short EMA fell below long EMA. + # A new death cross may set _pending_signal to "SELL", but the original "BUY" is gone. + assert strategy._pending_signal != "BUY" + + +def test_ema_immediate_mode(): + """With pullback_enabled=False, original immediate entry works.""" + strategy = EmaCrossoverStrategy() + strategy.configure( + { + "short_period": 3, + "long_period": 6, + "quantity": "0.01", + "pullback_enabled": False, + } + ) + + # Declining prices so short EMA stays below long EMA + declining = [100, 98, 96, 94, 92, 90, 88, 86, 84, 82] + for price in declining: + strategy.on_candle(make_candle(price)) + + # Sharp rise to force golden cross — immediate mode should fire signal + rising = [120, 140, 160] + signal = None + for price in rising: + result = strategy.on_candle(make_candle(price)) + if result is not None: + signal = result + + assert signal is not None + assert signal.side == OrderSide.BUY + assert "Golden Cross" in signal.reason + # No pending signal should be set + assert strategy._pending_signal is None diff --git a/services/strategy-engine/tests/test_engine.py b/services/strategy-engine/tests/test_engine.py index ac9a596..2623027 100644 --- a/services/strategy-engine/tests/test_engine.py +++ b/services/strategy-engine/tests/test_engine.py @@ -13,7 +13,7 @@ from strategy_engine.engine import StrategyEngine def make_candle_event() -> dict: candle = Candle( - symbol="BTC/USDT", + symbol="AAPL", timeframe="1m", open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), open=Decimal("50000"), @@ -28,7 +28,7 @@ def make_candle_event() -> dict: def make_signal() -> Signal: return Signal( strategy="test", - symbol="BTC/USDT", + symbol="AAPL", side=OrderSide.BUY, price=Decimal("50050"), quantity=Decimal("0.01"), @@ -46,12 +46,12 @@ async def test_engine_dispatches_candle_to_strategies(): strategy.on_candle = MagicMock(return_value=None) engine = StrategyEngine(broker=broker, strategies=[strategy]) - await engine.process_once("candles.BTC_USDT", "0") + await engine.process_once("candles.AAPL", "0") strategy.on_candle.assert_called_once() candle_arg = strategy.on_candle.call_args[0][0] assert isinstance(candle_arg, Candle) - assert candle_arg.symbol == "BTC/USDT" + assert candle_arg.symbol == "AAPL" @pytest.mark.asyncio @@ -64,7 +64,7 @@ async def test_engine_publishes_signal_when_strategy_returns_one(): strategy.on_candle = MagicMock(return_value=make_signal()) engine = StrategyEngine(broker=broker, strategies=[strategy]) - await engine.process_once("candles.BTC_USDT", "0") + await engine.process_once("candles.AAPL", "0") broker.publish.assert_called_once() call_args = broker.publish.call_args diff --git a/services/strategy-engine/tests/test_grid_strategy.py b/services/strategy-engine/tests/test_grid_strategy.py index 79eb22a..878b900 100644 --- a/services/strategy-engine/tests/test_grid_strategy.py +++ b/services/strategy-engine/tests/test_grid_strategy.py @@ -10,7 +10,7 @@ from strategies.grid_strategy import GridStrategy def make_candle(close: float) -> Candle: return Candle( - symbol="BTC/USDT", + symbol="AAPL", timeframe="1m", open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), open=Decimal(str(close)), @@ -60,3 +60,41 @@ def test_grid_strategy_no_signal_in_same_zone(): strategy.on_candle(make_candle(50000)) signal = strategy.on_candle(make_candle(50100)) assert signal is None + + +def test_grid_exit_on_trend_break(): + """Price drops well below grid range → SELL signal emitted.""" + strategy = _configured_strategy() + # Grid range is 48000-52000, exit_threshold_pct defaults to 5% + # Lower bound = 48000 * 0.95 = 45600 + # Establish a zone first + strategy.on_candle(make_candle(50000)) + # Price drops far below the grid range + signal = strategy.on_candle(make_candle(45000)) + assert signal is not None + assert signal.side == OrderSide.SELL + assert "broke out of range" in signal.reason + + +def test_grid_no_signal_while_out_of_range(): + """After exit signal, no more grid signals until price returns to range.""" + strategy = _configured_strategy() + # Establish a zone + strategy.on_candle(make_candle(50000)) + # First out-of-range candle → SELL exit signal + signal = strategy.on_candle(make_candle(45000)) + assert signal is not None + assert signal.side == OrderSide.SELL + + # Subsequent out-of-range candles → no signals + signal = strategy.on_candle(make_candle(44000)) + assert signal is None + + signal = strategy.on_candle(make_candle(43000)) + assert signal is None + + # Price returns to grid range → grid signals resume + strategy.on_candle(make_candle(50000)) + signal = strategy.on_candle(make_candle(48100)) + assert signal is not None + assert signal.side == OrderSide.BUY diff --git a/services/strategy-engine/tests/test_indicators.py b/services/strategy-engine/tests/test_indicators.py index ac5b505..481569b 100644 --- a/services/strategy-engine/tests/test_indicators.py +++ b/services/strategy-engine/tests/test_indicators.py @@ -1,6 +1,8 @@ """Tests for technical indicator library.""" + import sys from pathlib import Path + sys.path.insert(0, str(Path(__file__).resolve().parents[1])) import pandas as pd diff --git a/services/strategy-engine/tests/test_macd_strategy.py b/services/strategy-engine/tests/test_macd_strategy.py index 9931b43..556fd4c 100644 --- a/services/strategy-engine/tests/test_macd_strategy.py +++ b/services/strategy-engine/tests/test_macd_strategy.py @@ -10,7 +10,7 @@ from strategies.macd_strategy import MacdStrategy def _candle(price: float) -> Candle: return Candle( - symbol="BTC/USDT", + symbol="AAPL", timeframe="1m", open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), open=Decimal(str(price)), @@ -78,3 +78,63 @@ def test_macd_reset_clears_state(): s.reset() assert len(s._closes) == 0 assert s._prev_histogram is None + assert s._prev_macd is None + assert s._prev_signal is None + + +def test_macd_signal_line_crossover(): + """Test that MACD signal-line crossover generates signals.""" + s = _make_strategy() + # Declining then rising prices should produce a signal-line bullish crossover + prices = [100, 99, 98, 97, 96, 95, 94, 93, 92, 91, 90, 89, 88] + prices += [89, 91, 94, 98, 103, 109, 116, 124, 133, 143] + signals = [] + for p in prices: + result = s.on_candle(_candle(float(p))) + if result is not None: + signals.append(result) + + buy_signals = [sig for sig in signals if sig.side == OrderSide.BUY] + assert len(buy_signals) > 0, "Expected at least one BUY signal" + # Check that at least one is a signal-line crossover or histogram crossover + all_reasons = [sig.reason for sig in buy_signals] + assert any("crossover" in r for r in all_reasons), ( + f"Expected crossover signal, got: {all_reasons}" + ) + + +def test_macd_conviction_varies_with_distance(): + """Test that conviction varies based on MACD distance from zero line.""" + s1 = _make_strategy() + s2 = _make_strategy() + + # Small price movements -> MACD near zero -> lower conviction + small_prices = [100, 99.5, 99, 98.5, 98, 97.5, 97, 96.5, 96, 95.5, 95, 94.5, 94] + small_prices += [94.5, 95, 95.5, 96, 96.5, 97, 97.5, 98, 98.5, 99] + small_signals = [] + for p in small_prices: + result = s1.on_candle(_candle(float(p))) + if result is not None: + small_signals.append(result) + + # Large price movements -> MACD far from zero -> higher conviction + large_prices = [100, 95, 90, 85, 80, 75, 70, 65, 60, 55, 50, 45, 40] + large_prices += [45, 55, 70, 90, 115, 145, 180, 220, 265, 315] + large_signals = [] + for p in large_prices: + result = s2.on_candle(_candle(float(p))) + if result is not None: + large_signals.append(result) + + # Both should produce signals + assert len(small_signals) > 0, "Expected signals from small movements" + assert len(large_signals) > 0, "Expected signals from large movements" + + # The large-movement signals should generally have higher conviction + # (or at least different conviction, since distance from zero affects it) + small_conv = small_signals[-1].conviction + large_conv = large_signals[-1].conviction + # Large movements should produce conviction >= small movements + assert large_conv >= small_conv, ( + f"Expected large movement conviction ({large_conv}) >= small ({small_conv})" + ) diff --git a/services/strategy-engine/tests/test_moc_strategy.py b/services/strategy-engine/tests/test_moc_strategy.py new file mode 100644 index 0000000..1928a28 --- /dev/null +++ b/services/strategy-engine/tests/test_moc_strategy.py @@ -0,0 +1,152 @@ +"""Tests for MOC (Market on Close) strategy.""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from datetime import datetime, timezone +from decimal import Decimal + +from shared.models import Candle, OrderSide +from strategies.moc_strategy import MocStrategy + + +def _candle(price, hour=20, minute=0, volume=100.0, day=1, open_price=None): + op = open_price if open_price is not None else price - 1 # Default: bullish + return Candle( + symbol="AAPL", + timeframe="5Min", + open_time=datetime(2025, 1, day, hour, minute, tzinfo=timezone.utc), + open=Decimal(str(op)), + high=Decimal(str(price + 1)), + low=Decimal(str(min(op, price) - 1)), + close=Decimal(str(price)), + volume=Decimal(str(volume)), + ) + + +def _make_strategy(**overrides): + s = MocStrategy() + params = { + "quantity_pct": 0.2, + "stop_loss_pct": 2.0, + "rsi_min": 30, + "rsi_max": 70, # Wider for tests + "ema_period": 5, + "volume_avg_period": 5, + "min_volume_ratio": 0.5, + "buy_start_utc": 19, + "buy_end_utc": 21, + "sell_start_utc": 13, + "sell_end_utc": 15, + "max_positions": 5, + } + params.update(overrides) + s.configure(params) + return s + + +def test_moc_warmup_period(): + s = _make_strategy(ema_period=20, volume_avg_period=15) + assert s.warmup_period == 21 + + +def test_moc_no_signal_outside_buy_window(): + s = _make_strategy() + # Hour 12 UTC — not in buy (19-21) or sell (13-15) window + for i in range(10): + sig = s.on_candle(_candle(150 + i, hour=12, minute=i * 5)) + assert sig is None + + +def test_moc_buy_signal_in_window(): + s = _make_strategy(ema_period=3) + # Build up history with some oscillation so RSI settles in the 30-70 range + prices = [ + 150, + 149, + 151, + 148, + 152, + 149, + 150, + 151, + 148, + 150, + 149, + 151, + 150, + 152, + 151, + 153, + 152, + 154, + 153, + 155, + ] + signals = [] + for i, p in enumerate(prices): + sig = s.on_candle(_candle(p, hour=20, minute=i * 2, volume=200.0)) + if sig is not None: + signals.append(sig) + buy_signals = [sig for sig in signals if sig.side == OrderSide.BUY] + assert len(buy_signals) > 0 + assert buy_signals[0].strategy == "moc" + + +def test_moc_sell_at_open(): + s = _make_strategy(ema_period=3) + # Force entry + for i in range(10): + s.on_candle(_candle(150 + i, hour=20, minute=i * 3, volume=200.0)) + + if s._in_position: + # Next day, sell window + sig = s.on_candle(_candle(155, hour=14, minute=0, day=2)) + assert sig is not None + assert sig.side == OrderSide.SELL + assert "MOC sell" in sig.reason + + +def test_moc_stop_loss(): + s = _make_strategy(ema_period=3, stop_loss_pct=1.0) + for i in range(10): + s.on_candle(_candle(150 + i, hour=20, minute=i * 3, volume=200.0)) + + if s._in_position: + drop_price = s._entry_price * 0.98 # -2% + sig = s.on_candle(_candle(drop_price, hour=22, minute=0)) + if sig is not None: + assert sig.side == OrderSide.SELL + assert "stop loss" in sig.reason + + +def test_moc_no_buy_on_bearish_candle(): + s = _make_strategy(ema_period=3) + for i in range(8): + s.on_candle(_candle(150, hour=20, minute=i * 3, volume=200.0)) + # Bearish candle (open > close) + s.on_candle(_candle(149, hour=20, minute=30, open_price=151)) + # May or may not signal depending on other criteria, but bearish should reduce chances + # Just verify no crash + + +def test_moc_only_one_buy_per_day(): + s = _make_strategy(ema_period=3) + buy_count = 0 + for i in range(20): + sig = s.on_candle(_candle(150 + i * 0.3, hour=20, minute=i * 2, volume=200.0)) + if sig is not None and sig.side == OrderSide.BUY: + buy_count += 1 + assert buy_count <= 1 + + +def test_moc_reset(): + s = _make_strategy() + s.on_candle(_candle(150, hour=20)) + s._in_position = True + s.reset() + assert not s._in_position + assert len(s._closes) == 0 + assert not s._bought_today diff --git a/services/strategy-engine/tests/test_multi_symbol.py b/services/strategy-engine/tests/test_multi_symbol.py index cb8088c..671a9d3 100644 --- a/services/strategy-engine/tests/test_multi_symbol.py +++ b/services/strategy-engine/tests/test_multi_symbol.py @@ -22,7 +22,7 @@ async def test_engine_processes_multiple_streams(): broker = AsyncMock() candle_btc = Candle( - symbol="BTCUSDT", + symbol="AAPL", timeframe="1m", open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), open=Decimal("50000"), @@ -32,7 +32,7 @@ async def test_engine_processes_multiple_streams(): volume=Decimal("10"), ) candle_eth = Candle( - symbol="ETHUSDT", + symbol="MSFT", timeframe="1m", open_time=datetime(2025, 1, 1, tzinfo=timezone.utc), open=Decimal("3000"), @@ -45,16 +45,16 @@ async def test_engine_processes_multiple_streams(): btc_events = [CandleEvent(data=candle_btc).to_dict()] eth_events = [CandleEvent(data=candle_eth).to_dict()] - # First call returns BTC event, second ETH, then empty - call_count = {"btc": 0, "eth": 0} + # First call returns AAPL event, second MSFT, then empty + call_count = {"aapl": 0, "msft": 0} async def mock_read(stream, **kwargs): - if "BTC" in stream: - call_count["btc"] += 1 - return btc_events if call_count["btc"] == 1 else [] - elif "ETH" in stream: - call_count["eth"] += 1 - return eth_events if call_count["eth"] == 1 else [] + if "AAPL" in stream: + call_count["aapl"] += 1 + return btc_events if call_count["aapl"] == 1 else [] + elif "MSFT" in stream: + call_count["msft"] += 1 + return eth_events if call_count["msft"] == 1 else [] return [] broker.read = AsyncMock(side_effect=mock_read) @@ -65,8 +65,8 @@ async def test_engine_processes_multiple_streams(): engine = StrategyEngine(broker=broker, strategies=[strategy]) # Process both streams - await engine.process_once("candles.BTCUSDT", "$") - await engine.process_once("candles.ETHUSDT", "$") + await engine.process_once("candles.AAPL", "$") + await engine.process_once("candles.MSFT", "$") # Strategy should have been called with both candles assert strategy.on_candle.call_count == 2 diff --git a/services/strategy-engine/tests/test_rsi_strategy.py b/services/strategy-engine/tests/test_rsi_strategy.py index 2a2f4e7..6d31fd5 100644 --- a/services/strategy-engine/tests/test_rsi_strategy.py +++ b/services/strategy-engine/tests/test_rsi_strategy.py @@ -10,7 +10,7 @@ from strategies.rsi_strategy import RsiStrategy def make_candle(close: float, idx: int = 0) -> Candle: return Candle( - symbol="BTC/USDT", + symbol="AAPL", timeframe="1m", open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), open=Decimal(str(close)), @@ -43,3 +43,60 @@ def test_rsi_strategy_buy_signal_on_oversold(): # if a signal is returned, it must be a BUY if signal is not None: assert signal.side == OrderSide.BUY + + +def test_rsi_detects_bullish_divergence(): + """Bullish divergence: price makes lower low, RSI makes higher low.""" + strategy = RsiStrategy() + strategy.configure({"period": 5, "oversold": 20, "overbought": 80}) + strategy._filter_enabled = False # Disable filters to test divergence logic only + + # Sharp consecutive drop to 50 drives RSI near 0 (first swing low). + # Big recovery, then gradual decline to 48 (lower price, but RSI > 0 = higher low). + prices = [100.0] * 7 + prices += [85.0, 70.0, 55.0, 50.0] + prices += [55.0, 65.0, 80.0, 95.0, 110.0, 120.0, 130.0, 135.0, 140.0, 142.0, 143.0, 144.0] + prices += [142.0, 140.0, 138.0, 135.0, 130.0, 125.0, 120.0, 115.0, 110.0, 105.0] + prices += [100.0, 95.0, 90.0, 85.0, 80.0, 75.0, 70.0, 65.0, 60.0, 55.0, 50.0, 48.0] + prices += [52.0, 58.0] + + signals = [] + for p in prices: + result = strategy.on_candle(make_candle(p)) + if result is not None: + signals.append(result) + + divergence_signals = [s for s in signals if "divergence" in s.reason] + assert len(divergence_signals) > 0, "Expected at least one bullish divergence signal" + assert divergence_signals[0].side == OrderSide.BUY + assert divergence_signals[0].conviction == 0.9 + assert "bullish divergence" in divergence_signals[0].reason + + +def test_rsi_detects_bearish_divergence(): + """Bearish divergence: price makes higher high, RSI makes lower high.""" + strategy = RsiStrategy() + strategy.configure({"period": 5, "oversold": 20, "overbought": 80}) + strategy._filter_enabled = False # Disable filters to test divergence logic only + + # Sharp consecutive rise to 160 drives RSI very high (first swing high). + # Deep pullback, then rise to 162 (higher price) but with a dip right before + # the peak to dampen RSI (lower high). + prices = [100.0] * 7 + prices += [110.0, 120.0, 130.0, 140.0, 150.0, 160.0] + prices += [155.0, 145.0, 130.0, 115.0, 100.0, 90.0, 80.0] + prices += [90.0, 100.0, 110.0, 120.0, 130.0, 140.0, 150.0] + prices += [145.0, 162.0] + prices += [155.0, 148.0] + + signals = [] + for p in prices: + result = strategy.on_candle(make_candle(p)) + if result is not None: + signals.append(result) + + divergence_signals = [s for s in signals if "divergence" in s.reason] + assert len(divergence_signals) > 0, "Expected at least one bearish divergence signal" + assert divergence_signals[0].side == OrderSide.SELL + assert divergence_signals[0].conviction == 0.9 + assert "bearish divergence" in divergence_signals[0].reason diff --git a/services/strategy-engine/tests/test_volume_profile_strategy.py b/services/strategy-engine/tests/test_volume_profile_strategy.py index 71f0eca..65ee2e8 100644 --- a/services/strategy-engine/tests/test_volume_profile_strategy.py +++ b/services/strategy-engine/tests/test_volume_profile_strategy.py @@ -10,7 +10,7 @@ from strategies.volume_profile_strategy import VolumeProfileStrategy def make_candle(close: float, volume: float = 1.0) -> Candle: return Candle( - symbol="BTC/USDT", + symbol="AAPL", timeframe="1m", open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), open=Decimal(str(close)), @@ -125,3 +125,56 @@ def test_volume_profile_reset_clears_state(): # After reset, should not have enough data result = strategy.on_candle(make_candle(100.0, 10.0)) assert result is None + + +def test_volume_profile_hvn_detection(): + """Feed clustered volume at specific price levels to produce HVN nodes.""" + strategy = VolumeProfileStrategy() + strategy.configure({"lookback_period": 20, "num_bins": 10, "value_area_pct": 0.7}) + + # Create a profile with very high volume at price ~100 and low volume elsewhere + # Prices range from 90 to 110, heavy volume concentrated at 100 + candles_data = [] + # Low volume at extremes + for p in [90, 91, 92, 109, 110]: + candles_data.append((p, 1.0)) + # Very high volume around 100 + for _ in range(15): + candles_data.append((100, 100.0)) + + for price, vol in candles_data: + strategy.on_candle(make_candle(price, vol)) + + # Access the internal method to verify HVN detection + result = strategy._compute_value_area() + assert result is not None + poc, va_low, va_high, hvn_levels, lvn_levels = result + + # The bin containing price ~100 should have very high volume -> HVN + assert len(hvn_levels) > 0 + # At least one HVN should be near 100 + assert any(abs(h - 100) < 5 for h in hvn_levels) + + +def test_volume_profile_reset_thorough(): + """Verify all state is cleared on reset.""" + strategy = VolumeProfileStrategy() + strategy.configure({"lookback_period": 10, "num_bins": 5}) + + # Build up state + for _ in range(10): + strategy.on_candle(make_candle(100.0, 10.0)) + # Set below/above VA flags + strategy.on_candle(make_candle(50.0, 1.0)) # below VA + strategy.on_candle(make_candle(200.0, 1.0)) # above VA + + strategy.reset() + + # Verify all state cleared + assert len(strategy._candles) == 0 + assert strategy._was_below_va is False + assert strategy._was_above_va is False + + # Should not produce signal since no data + result = strategy.on_candle(make_candle(100.0, 10.0)) + assert result is None diff --git a/services/strategy-engine/tests/test_vwap_strategy.py b/services/strategy-engine/tests/test_vwap_strategy.py index 5d76b04..2c34b01 100644 --- a/services/strategy-engine/tests/test_vwap_strategy.py +++ b/services/strategy-engine/tests/test_vwap_strategy.py @@ -13,15 +13,18 @@ def make_candle( high: float | None = None, low: float | None = None, volume: float = 1.0, + open_time: datetime | None = None, ) -> Candle: if high is None: high = close if low is None: low = close + if open_time is None: + open_time = datetime(2024, 1, 1, tzinfo=timezone.utc) return Candle( - symbol="BTC/USDT", + symbol="AAPL", timeframe="1m", - open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open_time=open_time, open=Decimal(str(close)), high=Decimal(str(high)), low=Decimal(str(low)), @@ -99,3 +102,46 @@ def test_vwap_reset_clears_state(): assert strategy._candle_count == 0 assert strategy._was_below_vwap is False assert strategy._was_above_vwap is False + assert strategy._current_date is None + assert len(strategy._tp_values) == 0 + assert len(strategy._vwap_values) == 0 + + +def test_vwap_daily_reset(): + """Candles from two different dates cause VWAP to reset.""" + strategy = _configured_strategy() + + day1 = datetime(2024, 1, 1, tzinfo=timezone.utc) + day2 = datetime(2024, 1, 2, tzinfo=timezone.utc) + + # Feed 35 candles on day 1 to build VWAP state + for i in range(35): + strategy.on_candle(make_candle(100.0, high=101.0, low=99.0, open_time=day1)) + + # Verify state is built up + assert strategy._candle_count == 35 + assert strategy._cumulative_vol > 0 + assert strategy._current_date == "2024-01-01" + + # Feed first candle of day 2 — should reset + strategy.on_candle(make_candle(100.0, high=101.0, low=99.0, open_time=day2)) + + # After reset, candle_count should be 1 (the new candle) + assert strategy._candle_count == 1 + assert strategy._current_date == "2024-01-02" + + +def test_vwap_reset_clears_date(): + """Verify reset() clears _current_date and deviation band state.""" + strategy = _configured_strategy() + + for _ in range(35): + strategy.on_candle(make_candle(100.0)) + + assert strategy._current_date is not None + + strategy.reset() + + assert strategy._current_date is None + assert len(strategy._tp_values) == 0 + assert len(strategy._vwap_values) == 0 |
