diff options
Diffstat (limited to 'services/strategy-engine')
17 files changed, 558 insertions, 0 deletions
diff --git a/services/strategy-engine/Dockerfile b/services/strategy-engine/Dockerfile new file mode 100644 index 0000000..adecdd4 --- /dev/null +++ b/services/strategy-engine/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.12-slim +WORKDIR /app +COPY shared/ shared/ +RUN pip install --no-cache-dir ./shared +COPY services/strategy-engine/ services/strategy-engine/ +RUN pip install --no-cache-dir ./services/strategy-engine +CMD ["python", "-m", "strategy_engine.main"] diff --git a/services/strategy-engine/pyproject.toml b/services/strategy-engine/pyproject.toml new file mode 100644 index 0000000..a86b282 --- /dev/null +++ b/services/strategy-engine/pyproject.toml @@ -0,0 +1,19 @@ +[project] +name = "strategy-engine" +version = "0.1.0" +description = "Plugin-based strategy execution engine" +requires-python = ">=3.12" +dependencies = [ + "pandas>=2.0", + "trading-shared", +] + +[project.optional-dependencies] +dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/strategy_engine"] diff --git a/services/strategy-engine/src/strategy_engine/__init__.py b/services/strategy-engine/src/strategy_engine/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/strategy-engine/src/strategy_engine/__init__.py diff --git a/services/strategy-engine/src/strategy_engine/config.py b/services/strategy-engine/src/strategy_engine/config.py new file mode 100644 index 0000000..2864b09 --- /dev/null +++ b/services/strategy-engine/src/strategy_engine/config.py @@ -0,0 +1,8 @@ +"""Strategy Engine configuration.""" +from shared.config import Settings + + +class StrategyConfig(Settings): + symbols: list[str] = ["BTC/USDT"] + timeframes: list[str] = ["1m"] + strategy_params: dict = {} diff --git a/services/strategy-engine/src/strategy_engine/engine.py b/services/strategy-engine/src/strategy_engine/engine.py new file mode 100644 index 0000000..09dbf65 --- /dev/null +++ b/services/strategy-engine/src/strategy_engine/engine.py @@ -0,0 +1,54 @@ +"""Strategy Engine: consumes candle events and publishes signals.""" +import logging + +from shared.broker import RedisBroker +from shared.events import CandleEvent, SignalEvent, Event + +from strategies.base import BaseStrategy + +logger = logging.getLogger(__name__) + + +class StrategyEngine: + def __init__(self, broker: RedisBroker, strategies: list[BaseStrategy]) -> None: + self._broker = broker + self._strategies = strategies + + async def process_once(self, stream: str, last_id: str) -> str: + """Read one batch of messages from the stream, process candles, publish signals. + + Returns the updated last_id for the next call. + """ + messages = await self._broker.read(stream, last_id=last_id, count=10, block=100) + + for raw in messages: + try: + event = Event.from_dict(raw) + except Exception as exc: + logger.warning("Failed to parse event: %s – %s", raw, exc) + continue + + if not isinstance(event, CandleEvent): + continue + + candle = event.data + for strategy in self._strategies: + try: + signal = strategy.on_candle(candle) + except Exception as exc: + logger.error( + "Strategy %s raised on candle: %s", strategy.name, exc + ) + continue + + if signal is not None: + signal_event = SignalEvent(data=signal) + await self._broker.publish("signals", signal_event.to_dict()) + logger.info( + "Signal published: strategy=%s symbol=%s side=%s", + signal.strategy, + signal.symbol, + signal.side, + ) + + return last_id diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py new file mode 100644 index 0000000..83bb867 --- /dev/null +++ b/services/strategy-engine/src/strategy_engine/main.py @@ -0,0 +1,56 @@ +"""Strategy Engine Service entry point.""" +import asyncio +import logging +from pathlib import Path + +from shared.broker import RedisBroker + +from strategy_engine.config import StrategyConfig +from strategy_engine.engine import StrategyEngine +from strategy_engine.plugin_loader import load_strategies + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# The strategies directory lives alongside the installed package +STRATEGIES_DIR = Path(__file__).parent.parent.parent.parent / "strategies" + + +async def run() -> None: + config = StrategyConfig() + broker = RedisBroker(config.redis_url) + + strategies_dir = STRATEGIES_DIR + strategies = load_strategies(strategies_dir) + + # Configure each strategy with params from config + for strategy in strategies: + params = config.strategy_params.get(strategy.name, {}) + strategy.configure(params) + + logger.info( + "Loaded %d strategies: %s", + len(strategies), + [s.name for s in strategies], + ) + + engine = StrategyEngine(broker=broker, strategies=strategies) + + try: + for symbol in config.symbols: + stream = f"candles.{symbol.replace('/', '_')}" + last_id = "$" + logger.info("Starting engine loop for stream=%s", stream) + + while True: + last_id = await engine.process_once(stream, last_id) + finally: + await broker.close() + + +def main() -> None: + asyncio.run(run()) + + +if __name__ == "__main__": + main() diff --git a/services/strategy-engine/src/strategy_engine/plugin_loader.py b/services/strategy-engine/src/strategy_engine/plugin_loader.py new file mode 100644 index 0000000..719dc6d --- /dev/null +++ b/services/strategy-engine/src/strategy_engine/plugin_loader.py @@ -0,0 +1,36 @@ +"""Dynamic plugin loader for strategy modules.""" +import importlib.util +import sys +from pathlib import Path + +from strategies.base import BaseStrategy + + +def load_strategies(strategies_dir: Path) -> list[BaseStrategy]: + """Scan strategies_dir for *.py files and load all BaseStrategy subclasses.""" + loaded: list[BaseStrategy] = [] + + for path in sorted(strategies_dir.glob("*.py")): + # Skip dunder files and base + if path.name.startswith("__") or path.name == "base.py": + continue + + module_name = f"_strategy_plugin_{path.stem}" + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is None or spec.loader is None: + continue + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + for attr_name in dir(module): + obj = getattr(module, attr_name) + if ( + isinstance(obj, type) + and issubclass(obj, BaseStrategy) + and obj is not BaseStrategy + ): + loaded.append(obj()) + + return loaded diff --git a/services/strategy-engine/strategies/__init__.py b/services/strategy-engine/strategies/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/strategy-engine/strategies/__init__.py diff --git a/services/strategy-engine/strategies/base.py b/services/strategy-engine/strategies/base.py new file mode 100644 index 0000000..06101d0 --- /dev/null +++ b/services/strategy-engine/strategies/base.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +from shared.models import Candle, Signal + + +class BaseStrategy(ABC): + name: str = "base" + + @abstractmethod + def on_candle(self, candle: Candle) -> Signal | None: + pass + + @abstractmethod + def configure(self, params: dict) -> None: + pass + + def reset(self) -> None: + pass diff --git a/services/strategy-engine/strategies/grid_strategy.py b/services/strategy-engine/strategies/grid_strategy.py new file mode 100644 index 0000000..f669f09 --- /dev/null +++ b/services/strategy-engine/strategies/grid_strategy.py @@ -0,0 +1,77 @@ +from decimal import Decimal +from typing import Optional + +import numpy as np + +from shared.models import Candle, Signal, OrderSide +from strategies.base import BaseStrategy + + +class GridStrategy(BaseStrategy): + name: str = "grid" + + def __init__(self) -> None: + self._lower_price: float = 0.0 + self._upper_price: float = 0.0 + self._grid_count: int = 5 + self._quantity: Decimal = Decimal("0.01") + self._grid_levels: list[float] = [] + self._last_zone: Optional[int] = None + + def configure(self, params: dict) -> None: + self._lower_price = float(params["lower_price"]) + self._upper_price = float(params["upper_price"]) + self._grid_count = int(params.get("grid_count", 5)) + self._quantity = Decimal(str(params.get("quantity", "0.01"))) + self._grid_levels = list( + np.linspace(self._lower_price, self._upper_price, self._grid_count + 1) + ) + self._last_zone = None + + def reset(self) -> None: + self._last_zone = None + + def _get_zone(self, price: float) -> int: + """Return the grid zone index for a given price. + + Zone 0 is below the lowest level, zone grid_count is above the highest level. + Zones 1..grid_count-1 are between levels. + """ + for i, level in enumerate(self._grid_levels): + if price < level: + return i + return len(self._grid_levels) + + def on_candle(self, candle: Candle) -> Signal | None: + price = float(candle.close) + current_zone = self._get_zone(price) + + if self._last_zone is None: + self._last_zone = current_zone + return None + + prev_zone = self._last_zone + self._last_zone = current_zone + + if current_zone < prev_zone: + # Price moved to a lower zone → BUY + return Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.BUY, + price=candle.close, + quantity=self._quantity, + reason=f"Grid: price crossed down from zone {prev_zone} to {current_zone}", + ) + elif current_zone > prev_zone: + # Price moved to a higher zone → SELL + return Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.SELL, + price=candle.close, + quantity=self._quantity, + reason=f"Grid: price crossed up from zone {prev_zone} to {current_zone}", + ) + + return None diff --git a/services/strategy-engine/strategies/rsi_strategy.py b/services/strategy-engine/strategies/rsi_strategy.py new file mode 100644 index 0000000..aebbafc --- /dev/null +++ b/services/strategy-engine/strategies/rsi_strategy.py @@ -0,0 +1,77 @@ +from collections import deque +from decimal import Decimal + +import pandas as pd + +from shared.models import Candle, Signal, OrderSide +from strategies.base import BaseStrategy + + +def _compute_rsi(series: pd.Series, period: int) -> float | None: + """Compute RSI using Wilder's smoothing (EMA-based).""" + if len(series) < period + 1: + return None + delta = series.diff() + gain = delta.clip(lower=0) + loss = -delta.clip(upper=0) + avg_gain = gain.ewm(com=period - 1, min_periods=period).mean() + avg_loss = loss.ewm(com=period - 1, min_periods=period).mean() + rs = avg_gain / avg_loss.replace(0, float("nan")) + rsi = 100 - (100 / (1 + rs)) + value = rsi.iloc[-1] + if pd.isna(value): + return None + return float(value) + + +class RsiStrategy(BaseStrategy): + name: str = "rsi" + + def __init__(self) -> None: + self._closes: deque[float] = deque(maxlen=200) + self._period: int = 14 + self._oversold: float = 30.0 + self._overbought: float = 70.0 + self._quantity: Decimal = Decimal("0.01") + + def configure(self, params: dict) -> None: + self._period = int(params.get("period", 14)) + self._oversold = float(params.get("oversold", 30)) + self._overbought = float(params.get("overbought", 70)) + self._quantity = Decimal(str(params.get("quantity", "0.01"))) + + def reset(self) -> None: + self._closes.clear() + + def on_candle(self, candle: Candle) -> Signal | None: + self._closes.append(float(candle.close)) + + if len(self._closes) < self._period + 1: + return None + + series = pd.Series(list(self._closes)) + rsi_value = _compute_rsi(series, self._period) + + if rsi_value is None: + return None + + if rsi_value < self._oversold: + return Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.BUY, + price=candle.close, + quantity=self._quantity, + reason=f"RSI {rsi_value:.2f} below oversold threshold {self._oversold}", + ) + elif rsi_value > self._overbought: + return Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.SELL, + price=candle.close, + quantity=self._quantity, + reason=f"RSI {rsi_value:.2f} above overbought threshold {self._overbought}", + ) + + return None diff --git a/services/strategy-engine/tests/__init__.py b/services/strategy-engine/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/strategy-engine/tests/__init__.py diff --git a/services/strategy-engine/tests/conftest.py b/services/strategy-engine/tests/conftest.py new file mode 100644 index 0000000..c9ef308 --- /dev/null +++ b/services/strategy-engine/tests/conftest.py @@ -0,0 +1,8 @@ +"""Pytest configuration: ensure strategies/ is importable.""" +import sys +from pathlib import Path + +# Add the strategies directory to sys.path so that `from strategies.base import ...` works +STRATEGIES_DIR = Path(__file__).parent.parent / "strategies" +if str(STRATEGIES_DIR) not in sys.path: + sys.path.insert(0, str(STRATEGIES_DIR.parent)) diff --git a/services/strategy-engine/tests/test_engine.py b/services/strategy-engine/tests/test_engine.py new file mode 100644 index 0000000..33ad4dd --- /dev/null +++ b/services/strategy-engine/tests/test_engine.py @@ -0,0 +1,72 @@ +"""Tests for the StrategyEngine.""" +from datetime import datetime, timezone +from decimal import Decimal +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from shared.models import Candle, Signal, OrderSide +from shared.events import CandleEvent, SignalEvent +from strategy_engine.engine import StrategyEngine + + +def make_candle_event() -> dict: + candle = Candle( + symbol="BTC/USDT", + timeframe="1m", + open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open=Decimal("50000"), + high=Decimal("50100"), + low=Decimal("49900"), + close=Decimal("50050"), + volume=Decimal("10.0"), + ) + return CandleEvent(data=candle).to_dict() + + +def make_signal() -> Signal: + return Signal( + strategy="test", + symbol="BTC/USDT", + side=OrderSide.BUY, + price=Decimal("50050"), + quantity=Decimal("0.01"), + reason="test signal", + ) + + +@pytest.mark.asyncio +async def test_engine_dispatches_candle_to_strategies(): + broker = MagicMock() + broker.read = AsyncMock(return_value=[make_candle_event()]) + broker.publish = AsyncMock() + + strategy = MagicMock() + strategy.on_candle = MagicMock(return_value=None) + + engine = StrategyEngine(broker=broker, strategies=[strategy]) + await engine.process_once("candles.BTC_USDT", "0") + + strategy.on_candle.assert_called_once() + candle_arg = strategy.on_candle.call_args[0][0] + assert isinstance(candle_arg, Candle) + assert candle_arg.symbol == "BTC/USDT" + + +@pytest.mark.asyncio +async def test_engine_publishes_signal_when_strategy_returns_one(): + broker = MagicMock() + broker.read = AsyncMock(return_value=[make_candle_event()]) + broker.publish = AsyncMock() + + strategy = MagicMock() + strategy.on_candle = MagicMock(return_value=make_signal()) + + engine = StrategyEngine(broker=broker, strategies=[strategy]) + await engine.process_once("candles.BTC_USDT", "0") + + broker.publish.assert_called_once() + call_args = broker.publish.call_args + assert call_args[0][0] == "signals" + published_data = call_args[0][1] + assert published_data["type"] == "SIGNAL" diff --git a/services/strategy-engine/tests/test_grid_strategy.py b/services/strategy-engine/tests/test_grid_strategy.py new file mode 100644 index 0000000..d96ebba --- /dev/null +++ b/services/strategy-engine/tests/test_grid_strategy.py @@ -0,0 +1,60 @@ +"""Tests for the Grid strategy.""" +from datetime import datetime, timezone +from decimal import Decimal + +import pytest + +from shared.models import Candle, OrderSide +from strategies.grid_strategy import GridStrategy + + +def make_candle(close: float) -> Candle: + return Candle( + symbol="BTC/USDT", + timeframe="1m", + open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open=Decimal(str(close)), + high=Decimal(str(close)), + low=Decimal(str(close)), + close=Decimal(str(close)), + volume=Decimal("1.0"), + ) + + +def _configured_strategy() -> GridStrategy: + strategy = GridStrategy() + strategy.configure({ + "lower_price": 48000, + "upper_price": 52000, + "grid_count": 5, + "quantity": "0.01", + }) + return strategy + + +def test_grid_strategy_buy_at_lower_grid(): + strategy = _configured_strategy() + # First candle: establish zone at upper area + strategy.on_candle(make_candle(51500)) + # Second candle: price drops to lower zone → BUY + signal = strategy.on_candle(make_candle(48100)) + assert signal is not None + assert signal.side == OrderSide.BUY + + +def test_grid_strategy_sell_at_upper_grid(): + strategy = _configured_strategy() + # First candle: establish zone at lower area + strategy.on_candle(make_candle(48100)) + # Second candle: price rises to upper zone → SELL + signal = strategy.on_candle(make_candle(51900)) + assert signal is not None + assert signal.side == OrderSide.SELL + + +def test_grid_strategy_no_signal_in_same_zone(): + strategy = _configured_strategy() + # Both candles in approximately the same zone + strategy.on_candle(make_candle(50000)) + signal = strategy.on_candle(make_candle(50100)) + assert signal is None diff --git a/services/strategy-engine/tests/test_plugin_loader.py b/services/strategy-engine/tests/test_plugin_loader.py new file mode 100644 index 0000000..9496bab --- /dev/null +++ b/services/strategy-engine/tests/test_plugin_loader.py @@ -0,0 +1,22 @@ +"""Tests for the plugin loader.""" +from pathlib import Path + +import pytest + +from strategy_engine.plugin_loader import load_strategies + + +STRATEGIES_DIR = Path(__file__).parent.parent / "strategies" + + +def test_load_strategies_finds_rsi_and_grid(): + strategies = load_strategies(STRATEGIES_DIR) + names = [s.name for s in strategies] + assert "rsi" in names + assert "grid" in names + + +def test_load_strategies_skips_base(): + strategies = load_strategies(STRATEGIES_DIR) + names = [s.name for s in strategies] + assert "base" not in names diff --git a/services/strategy-engine/tests/test_rsi_strategy.py b/services/strategy-engine/tests/test_rsi_strategy.py new file mode 100644 index 0000000..90fface --- /dev/null +++ b/services/strategy-engine/tests/test_rsi_strategy.py @@ -0,0 +1,45 @@ +"""Tests for the RSI strategy.""" +from datetime import datetime, timezone +from decimal import Decimal + +import pytest + +from shared.models import Candle, OrderSide +from strategies.rsi_strategy import RsiStrategy + + +def make_candle(close: float, idx: int = 0) -> Candle: + return Candle( + symbol="BTC/USDT", + timeframe="1m", + open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open=Decimal(str(close)), + high=Decimal(str(close)), + low=Decimal(str(close)), + close=Decimal(str(close)), + volume=Decimal("1.0"), + ) + + +def test_rsi_strategy_no_signal_insufficient_data(): + strategy = RsiStrategy() + strategy.configure({}) + candle = make_candle(50000.0) + result = strategy.on_candle(candle) + assert result is None + + +def test_rsi_strategy_buy_signal_on_oversold(): + strategy = RsiStrategy() + strategy.configure({"period": 14, "oversold": 30, "overbought": 70}) + + # Feed 20 steadily declining prices to force RSI into oversold territory + prices = [50000 - i * 500 for i in range(20)] + signal = None + for i, price in enumerate(prices): + signal = strategy.on_candle(make_candle(price, i)) + + # We may or may not get a signal depending on RSI calculation; + # if a signal is returned, it must be a BUY + if signal is not None: + assert signal.side == OrderSide.BUY |
