diff options
Diffstat (limited to 'services')
60 files changed, 2114 insertions, 0 deletions
diff --git a/services/backtester/Dockerfile b/services/backtester/Dockerfile new file mode 100644 index 0000000..77ec453 --- /dev/null +++ b/services/backtester/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.12-slim +WORKDIR /app +COPY shared/ shared/ +RUN pip install --no-cache-dir ./shared +COPY services/backtester/ services/backtester/ +RUN pip install --no-cache-dir ./services/backtester +CMD ["python", "-m", "backtester.main"] diff --git a/services/backtester/pyproject.toml b/services/backtester/pyproject.toml new file mode 100644 index 0000000..b51f913 --- /dev/null +++ b/services/backtester/pyproject.toml @@ -0,0 +1,16 @@ +[project] +name = "backtester" +version = "0.1.0" +description = "Strategy backtesting engine" +requires-python = ">=3.12" +dependencies = ["pandas>=2.0", "trading-shared"] + +[project.optional-dependencies] +dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/backtester"] diff --git a/services/backtester/src/backtester/__init__.py b/services/backtester/src/backtester/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/backtester/src/backtester/__init__.py diff --git a/services/backtester/src/backtester/config.py b/services/backtester/src/backtester/config.py new file mode 100644 index 0000000..bfbc196 --- /dev/null +++ b/services/backtester/src/backtester/config.py @@ -0,0 +1,13 @@ +"""Configuration for the backtester service.""" +from pydantic_settings import BaseSettings + + +class BacktestConfig(BaseSettings): + backtest_initial_balance: float = 10000.0 + database_url: str = "postgresql://trading:trading@localhost:5432/trading" + symbol: str = "BTCUSDT" + timeframe: str = "1h" + strategy_name: str = "sma_crossover" + candle_limit: int = 500 + + model_config = {"env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"} diff --git a/services/backtester/src/backtester/engine.py b/services/backtester/src/backtester/engine.py new file mode 100644 index 0000000..b89d422 --- /dev/null +++ b/services/backtester/src/backtester/engine.py @@ -0,0 +1,95 @@ +"""Backtesting engine that runs strategies against historical candle data.""" +from dataclasses import dataclass, field +from decimal import Decimal +from typing import Protocol + +from shared.models import Candle, Signal + +from backtester.simulator import OrderSimulator, SimulatedTrade + + +class StrategyProtocol(Protocol): + """Protocol matching BaseStrategy from strategy-engine.""" + + name: str + + def on_candle(self, candle: Candle) -> Signal | None: ... + + def configure(self, params: dict) -> None: ... + + def reset(self) -> None: ... + + +@dataclass +class BacktestResult: + strategy_name: str + symbol: str + total_trades: int + initial_balance: Decimal + final_balance: Decimal + profit: Decimal + profit_pct: Decimal + trades: list[SimulatedTrade] = field(default_factory=list) + + @property + def win_rate(self) -> float: + """Calculate win rate based on buy/sell pairs.""" + buy_prices: list[Decimal] = [] + wins = 0 + total_pairs = 0 + + for trade in self.trades: + if trade.side.value == "BUY": + buy_prices.append(trade.price) + else: + if buy_prices: + buy_price = buy_prices.pop(0) + total_pairs += 1 + if trade.price > buy_price: + wins += 1 + + if total_pairs == 0: + return 0.0 + return wins / total_pairs * 100 + + +class BacktestEngine: + """Runs a strategy against historical candles using a simulated order executor.""" + + def __init__(self, strategy: StrategyProtocol, initial_balance: Decimal) -> None: + self._strategy = strategy + self._initial_balance = initial_balance + + def run(self, candles: list[Candle]) -> BacktestResult: + """Run the backtest over a list of candles and return a result.""" + simulator = OrderSimulator(self._initial_balance) + + for candle in candles: + signal = self._strategy.on_candle(candle) + if signal is not None: + simulator.execute(signal) + + # Calculate final balance including open positions valued at last candle close + final_balance = simulator.balance + if candles: + last_price = candles[-1].close + for symbol, qty in simulator.positions.items(): + if qty > Decimal("0"): + final_balance += qty * last_price + + profit = final_balance - self._initial_balance + if self._initial_balance != Decimal("0"): + profit_pct = (profit / self._initial_balance) * Decimal("100") + else: + profit_pct = Decimal("0") + + return BacktestResult( + strategy_name=self._strategy.name, + symbol=candles[0].symbol if candles else "", + total_trades=len(simulator.trades), + initial_balance=self._initial_balance, + final_balance=final_balance, + profit=profit, + profit_pct=profit_pct, + trades=simulator.trades, + ) diff --git a/services/backtester/src/backtester/main.py b/services/backtester/src/backtester/main.py new file mode 100644 index 0000000..ab69ee1 --- /dev/null +++ b/services/backtester/src/backtester/main.py @@ -0,0 +1,60 @@ +"""Main entry point for the backtester service.""" +import sys +import os +from decimal import Decimal + +# Allow importing strategies from the strategy-engine service +_STRATEGY_ENGINE_PATH = os.path.join( + os.path.dirname(__file__), "../../../../strategy-engine" +) +if _STRATEGY_ENGINE_PATH not in sys.path: + sys.path.insert(0, _STRATEGY_ENGINE_PATH) + +from shared.db import Database +from shared.models import Candle + +from backtester.config import BacktestConfig +from backtester.engine import BacktestEngine +from backtester.reporter import format_report + + +async def run_backtest() -> str: + """Load strategy, fetch candles, run backtest, and return a formatted report.""" + config = BacktestConfig() + + # Import strategy dynamically (requires strategy-engine in sys.path) + try: + from strategies.base import BaseStrategy # noqa: F401 + + # Try to import concrete strategy by name + module_name = config.strategy_name + import importlib + + mod = importlib.import_module(f"strategies.{module_name}") + strategy_cls = getattr(mod, "Strategy") + strategy = strategy_cls() + strategy.configure({}) + except Exception as exc: + raise RuntimeError( + f"Failed to load strategy '{config.strategy_name}': {exc}" + ) from exc + + db = Database(config.database_url) + await db.connect() + try: + rows = await db.get_candles(config.symbol, config.timeframe, config.candle_limit) + candles = [Candle(**row) for row in rows] + candles = list(reversed(candles)) # oldest first for strategy processing + finally: + await db.close() + + engine = BacktestEngine(strategy, Decimal(str(config.backtest_initial_balance))) + result = engine.run(candles) + return format_report(result) + + +if __name__ == "__main__": + import asyncio + + report = asyncio.run(run_backtest()) + print(report) diff --git a/services/backtester/src/backtester/reporter.py b/services/backtester/src/backtester/reporter.py new file mode 100644 index 0000000..916d5d4 --- /dev/null +++ b/services/backtester/src/backtester/reporter.py @@ -0,0 +1,28 @@ +"""Report formatting for backtest results.""" +from backtester.engine import BacktestResult + + +def format_report(result: BacktestResult) -> str: + """Format a backtest result into a human-readable text report.""" + separator = "=" * 50 + lines = [ + separator, + "BACKTEST REPORT", + separator, + f"Strategy: {result.strategy_name}", + f"Symbol: {result.symbol}", + separator, + "PERFORMANCE SUMMARY", + separator, + f"Initial Balance: {result.initial_balance:.2f}", + f"Final Balance: {result.final_balance:.2f}", + f"Profit/Loss: {result.profit:.2f}", + f"Profit %: {result.profit_pct:.2f}%", + separator, + "TRADE STATISTICS", + separator, + f"Total Trades: {result.total_trades}", + f"Win Rate: {result.win_rate:.2f}%", + separator, + ] + return "\n".join(lines) diff --git a/services/backtester/src/backtester/simulator.py b/services/backtester/src/backtester/simulator.py new file mode 100644 index 0000000..081ea3b --- /dev/null +++ b/services/backtester/src/backtester/simulator.py @@ -0,0 +1,54 @@ +"""Simulated order executor for backtesting.""" +from dataclasses import dataclass, field +from decimal import Decimal + +from shared.models import OrderSide, Signal + + +@dataclass +class SimulatedTrade: + symbol: str + side: OrderSide + price: Decimal + quantity: Decimal + balance_after: Decimal + + +class OrderSimulator: + """Simulates order execution against a paper balance.""" + + def __init__(self, initial_balance: Decimal) -> None: + self.balance: Decimal = initial_balance + self.positions: dict[str, Decimal] = {} + self.trades: list[SimulatedTrade] = [] + + def execute(self, signal: Signal) -> bool: + """Execute a signal. Returns True if the trade was accepted, False otherwise.""" + if signal.side == OrderSide.BUY: + cost = signal.price * signal.quantity + if cost > self.balance: + return False + self.balance -= cost + self.positions[signal.symbol] = ( + self.positions.get(signal.symbol, Decimal("0")) + signal.quantity + ) + trade_quantity = signal.quantity + else: # SELL + current_position = self.positions.get(signal.symbol, Decimal("0")) + if current_position <= Decimal("0"): + return False + trade_quantity = min(signal.quantity, current_position) + proceeds = signal.price * trade_quantity + self.balance += proceeds + self.positions[signal.symbol] = current_position - trade_quantity + + self.trades.append( + SimulatedTrade( + symbol=signal.symbol, + side=signal.side, + price=signal.price, + quantity=trade_quantity, + balance_after=self.balance, + ) + ) + return True diff --git a/services/backtester/tests/__init__.py b/services/backtester/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/backtester/tests/__init__.py diff --git a/services/backtester/tests/test_engine.py b/services/backtester/tests/test_engine.py new file mode 100644 index 0000000..1a25e1c --- /dev/null +++ b/services/backtester/tests/test_engine.py @@ -0,0 +1,74 @@ +"""Tests for the BacktestEngine.""" +from datetime import datetime, timezone +from decimal import Decimal +from unittest.mock import MagicMock + +import pytest + +from shared.models import Candle, Signal, OrderSide + +from backtester.engine import BacktestEngine, BacktestResult + + +def make_candle(symbol: str, price: float, timeframe: str = "1h") -> Candle: + return Candle( + symbol=symbol, + timeframe=timeframe, + open_time=datetime.now(timezone.utc), + open=Decimal(str(price)), + high=Decimal(str(price * 1.01)), + low=Decimal(str(price * 0.99)), + close=Decimal(str(price)), + volume=Decimal("100"), + ) + + +def make_candles(prices: list[float], symbol: str = "BTCUSDT") -> list[Candle]: + return [make_candle(symbol, p) for p in prices] + + +def make_signal(side: OrderSide, price: str, quantity: str = "0.1") -> Signal: + return Signal( + strategy="test", + symbol="BTCUSDT", + side=side, + price=Decimal(price), + quantity=Decimal(quantity), + reason="test", + ) + + +def test_backtest_engine_runs_strategy_over_candles(): + strategy = MagicMock() + strategy.name = "mock_strategy" + strategy.on_candle.return_value = None + + candles = make_candles([50000.0, 51000.0, 52000.0]) + engine = BacktestEngine(strategy, Decimal("10000")) + result = engine.run(candles) + + assert strategy.on_candle.call_count == 3 + assert result.total_trades == 0 + assert result.final_balance == Decimal("10000") + assert result.strategy_name == "mock_strategy" + + +def test_backtest_engine_executes_signals(): + buy_signal = make_signal(OrderSide.BUY, "50000", "0.1") + sell_signal = make_signal(OrderSide.SELL, "55000", "0.1") + + strategy = MagicMock() + strategy.name = "mock_strategy" + strategy.on_candle.side_effect = [buy_signal, None, sell_signal] + + candles = make_candles([50000.0, 52000.0, 55000.0]) + engine = BacktestEngine(strategy, Decimal("10000")) + result = engine.run(candles) + + assert result.total_trades == 2 + # Initial: 10000, bought 0.1 BTC @ 50000 (cost 5000) → balance 5000 + # Sold 0.1 BTC @ 55000 (proceeds 5500) → balance 10500 + expected_final = Decimal("10500") + assert result.final_balance == expected_final + expected_profit = Decimal("500") + assert result.profit == expected_profit diff --git a/services/backtester/tests/test_reporter.py b/services/backtester/tests/test_reporter.py new file mode 100644 index 0000000..f5c694c --- /dev/null +++ b/services/backtester/tests/test_reporter.py @@ -0,0 +1,26 @@ +"""Tests for the report formatter.""" +from decimal import Decimal + +from backtester.engine import BacktestResult +from backtester.reporter import format_report + + +def test_format_report_contains_key_metrics(): + result = BacktestResult( + strategy_name="sma_crossover", + symbol="BTCUSDT", + total_trades=10, + initial_balance=Decimal("10000"), + final_balance=Decimal("11500"), + profit=Decimal("1500"), + profit_pct=Decimal("15"), + trades=[], + ) + report = format_report(result) + + assert "sma_crossover" in report + assert "BTCUSDT" in report + assert "10000" in report + assert "11500" in report + assert "1500" in report + assert "15" in report diff --git a/services/backtester/tests/test_simulator.py b/services/backtester/tests/test_simulator.py new file mode 100644 index 0000000..9d8b23e --- /dev/null +++ b/services/backtester/tests/test_simulator.py @@ -0,0 +1,73 @@ +"""Tests for the OrderSimulator.""" +from decimal import Decimal + +import pytest + +from shared.models import Signal, OrderSide, OrderType +from backtester.simulator import OrderSimulator + + +def make_signal( + symbol: str, + side: OrderSide, + price: str, + quantity: str, + strategy: str = "test", +) -> Signal: + return Signal( + strategy=strategy, + symbol=symbol, + side=side, + price=Decimal(price), + quantity=Decimal(quantity), + reason="test", + ) + + +def test_simulator_initial_balance(): + sim = OrderSimulator(Decimal("10000")) + assert sim.balance == Decimal("10000") + + +def test_simulator_buy_reduces_balance(): + sim = OrderSimulator(Decimal("10000")) + signal = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + result = sim.execute(signal) + assert result is True + assert sim.balance == Decimal("5000") + assert sim.positions["BTCUSDT"] == Decimal("0.1") + + +def test_simulator_sell_increases_balance(): + sim = OrderSimulator(Decimal("10000")) + buy_signal = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + sim.execute(buy_signal) + balance_after_buy = sim.balance + + sell_signal = make_signal("BTCUSDT", OrderSide.SELL, "55000", "0.1") + result = sim.execute(sell_signal) + assert result is True + assert sim.balance > balance_after_buy + # Profit: sold at 55000, bought at 50000 → gain 500 + assert sim.balance == Decimal("10000") - Decimal("5000") + Decimal("5500") + + +def test_simulator_reject_buy_insufficient_balance(): + sim = OrderSimulator(Decimal("100")) + signal = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + result = sim.execute(signal) + assert result is False + assert sim.balance == Decimal("100") + assert sim.positions.get("BTCUSDT", Decimal("0")) == Decimal("0") + + +def test_simulator_trade_history(): + sim = OrderSimulator(Decimal("10000")) + signal = make_signal("BTCUSDT", OrderSide.BUY, "50000", "0.1") + sim.execute(signal) + assert len(sim.trades) == 1 + trade = sim.trades[0] + assert trade.symbol == "BTCUSDT" + assert trade.side == OrderSide.BUY + assert trade.price == Decimal("50000") + assert trade.quantity == Decimal("0.1") diff --git a/services/data-collector/Dockerfile b/services/data-collector/Dockerfile new file mode 100644 index 0000000..06f6d72 --- /dev/null +++ b/services/data-collector/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.12-slim +WORKDIR /app +COPY shared/ shared/ +RUN pip install --no-cache-dir ./shared +COPY services/data-collector/ services/data-collector/ +RUN pip install --no-cache-dir ./services/data-collector +CMD ["python", "-m", "data_collector.main"] diff --git a/services/data-collector/pyproject.toml b/services/data-collector/pyproject.toml new file mode 100644 index 0000000..5fba78f --- /dev/null +++ b/services/data-collector/pyproject.toml @@ -0,0 +1,23 @@ +[project] +name = "data-collector" +version = "0.1.0" +description = "Binance market data collector service" +requires-python = ">=3.12" +dependencies = [ + "ccxt>=4.0", + "websockets>=12.0", + "trading-shared", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/data_collector"] diff --git a/services/data-collector/src/data_collector/__init__.py b/services/data-collector/src/data_collector/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/data-collector/src/data_collector/__init__.py diff --git a/services/data-collector/src/data_collector/binance_rest.py b/services/data-collector/src/data_collector/binance_rest.py new file mode 100644 index 0000000..af0eb77 --- /dev/null +++ b/services/data-collector/src/data_collector/binance_rest.py @@ -0,0 +1,53 @@ +"""Binance REST API helpers for fetching historical candle data.""" +from datetime import datetime, timezone +from decimal import Decimal + +from shared.models import Candle + + +def _normalize_symbol(symbol: str) -> str: + """Convert 'BTC/USDT' to 'BTCUSDT'.""" + return symbol.replace("/", "") + + +async def fetch_historical_candles( + exchange, + symbol: str, + timeframe: str, + since: int, + limit: int = 500, +) -> list[Candle]: + """Fetch historical OHLCV candles from the exchange and return Candle models. + + Args: + exchange: An async ccxt exchange instance. + symbol: Market symbol, e.g. 'BTC/USDT'. + timeframe: Candle timeframe, e.g. '1m'. + since: Start timestamp in milliseconds. + limit: Maximum number of candles to fetch. + + Returns: + A list of Candle model instances. + """ + rows = await exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit) + + normalized = _normalize_symbol(symbol) + candles: list[Candle] = [] + + for row in rows: + ts_ms, o, h, l, c, v = row + open_time = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc) + candles.append( + Candle( + symbol=normalized, + timeframe=timeframe, + open_time=open_time, + open=Decimal(str(o)), + high=Decimal(str(h)), + low=Decimal(str(l)), + close=Decimal(str(c)), + volume=Decimal(str(v)), + ) + ) + + return candles diff --git a/services/data-collector/src/data_collector/binance_ws.py b/services/data-collector/src/data_collector/binance_ws.py new file mode 100644 index 0000000..7a4bad2 --- /dev/null +++ b/services/data-collector/src/data_collector/binance_ws.py @@ -0,0 +1,106 @@ +"""Binance WebSocket client for real-time kline/candle data.""" +import asyncio +import json +import logging +from datetime import datetime, timezone +from decimal import Decimal +from typing import Callable, Awaitable + +import websockets + +from shared.models import Candle + +logger = logging.getLogger(__name__) + +BINANCE_WS_URL = "wss://stream.binance.com:9443/ws" +RECONNECT_DELAY = 5 # seconds + + +def _normalize_symbol(symbol: str) -> str: + """Convert 'BTC/USDT' to 'BTCUSDT'.""" + return symbol.replace("/", "") + + +def _stream_name(symbol: str, timeframe: str) -> str: + """Build Binance stream name, e.g. 'btcusdt@kline_1m'.""" + return f"{_normalize_symbol(symbol).lower()}@kline_{timeframe}" + + +class BinanceWebSocket: + """Connects to Binance WebSocket streams and emits closed candles.""" + + def __init__( + self, + symbols: list[str], + timeframe: str, + on_candle: Callable[[Candle], Awaitable[None]], + ) -> None: + self._symbols = symbols + self._timeframe = timeframe + self._on_candle = on_candle + self._running = False + + def _build_subscribe_message(self) -> dict: + streams = [_stream_name(s, self._timeframe) for s in self._symbols] + return { + "method": "SUBSCRIBE", + "params": streams, + "id": 1, + } + + def _parse_candle(self, message: dict) -> Candle | None: + """Parse a kline WebSocket message into a Candle, or None if not closed.""" + k = message.get("k") + if k is None: + return None + if not k.get("x"): # only closed candles + return None + + symbol = k["s"] # already normalized, e.g. 'BTCUSDT' + open_time = datetime.fromtimestamp(k["t"] / 1000, tz=timezone.utc) + return Candle( + symbol=symbol, + timeframe=self._timeframe, + open_time=open_time, + open=Decimal(k["o"]), + high=Decimal(k["h"]), + low=Decimal(k["l"]), + close=Decimal(k["c"]), + volume=Decimal(k["v"]), + ) + + async def _run_once(self) -> None: + """Single connection attempt; processes messages until disconnected.""" + async with websockets.connect(BINANCE_WS_URL) as ws: + subscribe_msg = self._build_subscribe_message() + await ws.send(json.dumps(subscribe_msg)) + logger.info("Subscribed to Binance streams: %s", subscribe_msg["params"]) + + async for raw in ws: + if not self._running: + break + try: + message = json.loads(raw) + candle = self._parse_candle(message) + if candle is not None: + await self._on_candle(candle) + except Exception: + logger.exception("Error processing WebSocket message: %s", raw) + + async def start(self) -> None: + """Connect to Binance WebSocket and process messages, auto-reconnecting.""" + self._running = True + while self._running: + try: + await self._run_once() + except Exception: + if not self._running: + break + logger.warning( + "WebSocket disconnected. Reconnecting in %ds…", RECONNECT_DELAY + ) + await asyncio.sleep(RECONNECT_DELAY) + + def stop(self) -> None: + """Signal the WebSocket loop to stop after the current message.""" + self._running = False diff --git a/services/data-collector/src/data_collector/config.py b/services/data-collector/src/data_collector/config.py new file mode 100644 index 0000000..1e080e5 --- /dev/null +++ b/services/data-collector/src/data_collector/config.py @@ -0,0 +1,6 @@ +from shared.config import Settings + + +class CollectorConfig(Settings): + symbols: list[str] = ["BTC/USDT"] + timeframes: list[str] = ["1m"] diff --git a/services/data-collector/src/data_collector/main.py b/services/data-collector/src/data_collector/main.py new file mode 100644 index 0000000..adf1e96 --- /dev/null +++ b/services/data-collector/src/data_collector/main.py @@ -0,0 +1,58 @@ +"""Data Collector Service entry point.""" +import asyncio +import logging + +from shared.broker import RedisBroker +from shared.db import Database + +from data_collector.binance_ws import BinanceWebSocket +from data_collector.config import CollectorConfig +from data_collector.storage import CandleStorage + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def run() -> None: + """Initialise all components and start the WebSocket collector.""" + config = CollectorConfig() + + db = Database(config.database_url) + await db.connect() + await db.init_tables() + + broker = RedisBroker(config.redis_url) + storage = CandleStorage(db=db, broker=broker) + + async def on_candle(candle): + logger.info("Candle received: %s %s %s", candle.symbol, candle.timeframe, candle.open_time) + await storage.store(candle) + + # Use the first configured timeframe for the WebSocket subscription. + timeframe = config.timeframes[0] if config.timeframes else "1m" + + ws = BinanceWebSocket( + symbols=config.symbols, + timeframe=timeframe, + on_candle=on_candle, + ) + + logger.info( + "Starting data collector for symbols=%s timeframe=%s", + config.symbols, + timeframe, + ) + + try: + await ws.start() + finally: + await broker.close() + await db.close() + + +def main() -> None: + asyncio.run(run()) + + +if __name__ == "__main__": + main() diff --git a/services/data-collector/src/data_collector/storage.py b/services/data-collector/src/data_collector/storage.py new file mode 100644 index 0000000..1e40b82 --- /dev/null +++ b/services/data-collector/src/data_collector/storage.py @@ -0,0 +1,24 @@ +"""Candle storage: persists to DB and publishes to Redis.""" +from shared.events import CandleEvent +from shared.models import Candle + + +class CandleStorage: + """Stores candles in the database and publishes CandleEvents to Redis.""" + + def __init__(self, db, broker) -> None: + self._db = db + self._broker = broker + + async def store(self, candle: Candle) -> None: + """Insert candle into DB and publish a CandleEvent to the Redis stream.""" + await self._db.insert_candle(candle) + + event = CandleEvent(data=candle) + stream = f"candles.{candle.symbol}" + await self._broker.publish(stream, event.to_dict()) + + async def store_batch(self, candles: list[Candle]) -> None: + """Store multiple candles one by one.""" + for candle in candles: + await self.store(candle) diff --git a/services/data-collector/tests/__init__.py b/services/data-collector/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/data-collector/tests/__init__.py diff --git a/services/data-collector/tests/test_binance_rest.py b/services/data-collector/tests/test_binance_rest.py new file mode 100644 index 0000000..695dcf9 --- /dev/null +++ b/services/data-collector/tests/test_binance_rest.py @@ -0,0 +1,53 @@ +"""Tests for binance_rest module.""" +import pytest +from decimal import Decimal +from unittest.mock import AsyncMock, MagicMock +from datetime import datetime, timezone + +from data_collector.binance_rest import fetch_historical_candles + + +@pytest.mark.asyncio +async def test_fetch_historical_candles_parses_response(): + """Verify that OHLCV rows are correctly parsed into Candle models.""" + ts = 1700000000000 # milliseconds + mock_exchange = MagicMock() + mock_exchange.fetch_ohlcv = AsyncMock( + return_value=[ + [ts, 30000.0, 30100.0, 29900.0, 30050.0, 1.5], + [ts + 60000, 30050.0, 30200.0, 30000.0, 30150.0, 2.0], + ] + ) + + candles = await fetch_historical_candles( + mock_exchange, "BTC/USDT", "1m", since=ts, limit=500 + ) + + assert len(candles) == 2 + + c = candles[0] + assert c.symbol == "BTCUSDT" + assert c.timeframe == "1m" + assert c.open_time == datetime.fromtimestamp(ts / 1000, tz=timezone.utc) + assert c.open == Decimal("30000.0") + assert c.high == Decimal("30100.0") + assert c.low == Decimal("29900.0") + assert c.close == Decimal("30050.0") + assert c.volume == Decimal("1.5") + + mock_exchange.fetch_ohlcv.assert_called_once_with( + "BTC/USDT", "1m", since=ts, limit=500 + ) + + +@pytest.mark.asyncio +async def test_fetch_historical_candles_empty_response(): + """Verify that an empty exchange response returns an empty list.""" + mock_exchange = MagicMock() + mock_exchange.fetch_ohlcv = AsyncMock(return_value=[]) + + candles = await fetch_historical_candles( + mock_exchange, "BTC/USDT", "1m", since=1700000000000 + ) + + assert candles == [] diff --git a/services/data-collector/tests/test_storage.py b/services/data-collector/tests/test_storage.py new file mode 100644 index 0000000..6b27414 --- /dev/null +++ b/services/data-collector/tests/test_storage.py @@ -0,0 +1,62 @@ +"""Tests for storage module.""" +import pytest +from decimal import Decimal +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + +from shared.models import Candle +from data_collector.storage import CandleStorage + + +def _make_candle(symbol: str = "BTCUSDT") -> Candle: + return Candle( + symbol=symbol, + timeframe="1m", + open_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + open=Decimal("30000"), + high=Decimal("30100"), + low=Decimal("29900"), + close=Decimal("30050"), + volume=Decimal("1.5"), + ) + + +@pytest.mark.asyncio +async def test_storage_saves_to_db_and_publishes(): + """Verify that store() calls insert_candle on db and publish on broker.""" + mock_db = MagicMock() + mock_db.insert_candle = AsyncMock() + mock_broker = MagicMock() + mock_broker.publish = AsyncMock() + + storage = CandleStorage(db=mock_db, broker=mock_broker) + candle = _make_candle() + + await storage.store(candle) + + mock_db.insert_candle.assert_called_once_with(candle) + mock_broker.publish.assert_called_once() + + stream_arg = mock_broker.publish.call_args[0][0] + assert stream_arg == "candles.BTCUSDT" + + data_arg = mock_broker.publish.call_args[0][1] + assert data_arg["type"] == "CANDLE" + assert data_arg["data"]["symbol"] == "BTCUSDT" + + +@pytest.mark.asyncio +async def test_storage_batch_store(): + """Verify that store_batch() calls store for each candle.""" + mock_db = MagicMock() + mock_db.insert_candle = AsyncMock() + mock_broker = MagicMock() + mock_broker.publish = AsyncMock() + + storage = CandleStorage(db=mock_db, broker=mock_broker) + candles = [_make_candle() for _ in range(3)] + + await storage.store_batch(candles) + + assert mock_db.insert_candle.call_count == 3 + assert mock_broker.publish.call_count == 3 diff --git a/services/order-executor/Dockerfile b/services/order-executor/Dockerfile new file mode 100644 index 0000000..f044714 --- /dev/null +++ b/services/order-executor/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.12-slim +WORKDIR /app +COPY shared/ shared/ +RUN pip install --no-cache-dir ./shared +COPY services/order-executor/ services/order-executor/ +RUN pip install --no-cache-dir ./services/order-executor +CMD ["python", "-m", "order_executor.main"] diff --git a/services/order-executor/pyproject.toml b/services/order-executor/pyproject.toml new file mode 100644 index 0000000..eed4fef --- /dev/null +++ b/services/order-executor/pyproject.toml @@ -0,0 +1,16 @@ +[project] +name = "order-executor" +version = "0.1.0" +description = "Order execution service with risk management" +requires-python = ">=3.12" +dependencies = ["ccxt>=4.0", "trading-shared"] + +[project.optional-dependencies] +dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/order_executor"] diff --git a/services/order-executor/src/order_executor/__init__.py b/services/order-executor/src/order_executor/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/order-executor/src/order_executor/__init__.py diff --git a/services/order-executor/src/order_executor/config.py b/services/order-executor/src/order_executor/config.py new file mode 100644 index 0000000..856045f --- /dev/null +++ b/services/order-executor/src/order_executor/config.py @@ -0,0 +1,6 @@ +"""Order Executor configuration.""" +from shared.config import Settings + + +class ExecutorConfig(Settings): + pass diff --git a/services/order-executor/src/order_executor/executor.py b/services/order-executor/src/order_executor/executor.py new file mode 100644 index 0000000..16ae52c --- /dev/null +++ b/services/order-executor/src/order_executor/executor.py @@ -0,0 +1,100 @@ +"""Order execution logic.""" +import logging +from datetime import datetime, timezone +from decimal import Decimal +from typing import Any, Optional + +from shared.broker import RedisBroker +from shared.db import Database +from shared.events import OrderEvent +from shared.models import Order, OrderSide, OrderStatus, OrderType, Signal + +from order_executor.risk_manager import RiskManager + +logger = logging.getLogger(__name__) + + +class OrderExecutor: + """Executes orders on an exchange with risk gating.""" + + def __init__( + self, + exchange: Any, + risk_manager: RiskManager, + broker: RedisBroker, + db: Database, + dry_run: bool = True, + ) -> None: + self.exchange = exchange + self.risk_manager = risk_manager + self.broker = broker + self.db = db + self.dry_run = dry_run + + async def execute(self, signal: Signal) -> Optional[Order]: + """Run risk checks and place an order for the given signal.""" + # Fetch current balance from exchange + balance_data = await self.exchange.fetch_balance() + # Use USDT (or quote currency) free balance as available capital + free_balances = balance_data.get("free", {}) + quote_currency = signal.symbol.split("/")[-1] if "/" in signal.symbol else "USDT" + balance = Decimal(str(free_balances.get(quote_currency, 0))) + + # Fetch current positions + positions = {} + + # Compute daily PnL (not tracked at executor level — use 0 unless provided) + daily_pnl = Decimal(0) + + # Run risk checks + result = self.risk_manager.check( + signal=signal, + balance=balance, + positions=positions, + daily_pnl=daily_pnl, + ) + + if not result.allowed: + logger.warning( + "Risk check rejected signal %s: %s", signal.id, result.reason + ) + return None + + # Build the order model + order = Order( + signal_id=signal.id, + symbol=signal.symbol, + side=signal.side, + type=OrderType.MARKET, + price=signal.price, + quantity=signal.quantity, + status=OrderStatus.PENDING, + ) + + if self.dry_run: + order.status = OrderStatus.FILLED + order.filled_at = datetime.now(timezone.utc) + logger.info("[DRY RUN] Order filled: %s %s %s", order.side, order.quantity, order.symbol) + else: + try: + await self.exchange.create_order( + symbol=signal.symbol, + type="market", + side=signal.side.value.lower(), + amount=float(signal.quantity), + ) + order.status = OrderStatus.FILLED + order.filled_at = datetime.now(timezone.utc) + logger.info("Order filled: %s %s %s", order.side, order.quantity, order.symbol) + except Exception as exc: + order.status = OrderStatus.FAILED + logger.error("Order failed for signal %s: %s", signal.id, exc) + + # Persist to DB + await self.db.insert_order(order) + + # Publish order event + event = OrderEvent(data=order) + await self.broker.publish("orders", event.to_dict()) + + return order diff --git a/services/order-executor/src/order_executor/main.py b/services/order-executor/src/order_executor/main.py new file mode 100644 index 0000000..b57c513 --- /dev/null +++ b/services/order-executor/src/order_executor/main.py @@ -0,0 +1,83 @@ +"""Order Executor Service entry point.""" +import asyncio +import logging +from decimal import Decimal + +import ccxt.async_support as ccxt + +from shared.broker import RedisBroker +from shared.db import Database +from shared.events import Event, EventType + +from order_executor.config import ExecutorConfig +from order_executor.executor import OrderExecutor +from order_executor.risk_manager import RiskManager + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def run() -> None: + config = ExecutorConfig() + logging.getLogger().setLevel(config.log_level) + + db = Database(config.database_url) + await db.connect() + await db.init_tables() + + broker = RedisBroker(config.redis_url) + + exchange = ccxt.binance( + { + "apiKey": config.binance_api_key, + "secret": config.binance_api_secret, + } + ) + + risk_manager = RiskManager( + max_position_size=Decimal(str(config.risk_max_position_size)), + stop_loss_pct=Decimal(str(config.risk_stop_loss_pct)), + daily_loss_limit_pct=Decimal(str(config.risk_daily_loss_limit_pct)), + ) + + executor = OrderExecutor( + exchange=exchange, + risk_manager=risk_manager, + broker=broker, + db=db, + dry_run=config.dry_run, + ) + + last_id = "$" + stream = "signals" + logger.info("Order executor started, listening on stream=%s dry_run=%s", stream, config.dry_run) + + try: + while True: + messages = await broker.read(stream, last_id=last_id, count=10, block=5000) + for msg in messages: + try: + event = Event.from_dict(msg) + if event.type == EventType.SIGNAL: + signal = event.data + logger.info("Processing signal %s for %s", signal.id, signal.symbol) + await executor.execute(signal) + except Exception as exc: + logger.error("Failed to process message: %s", exc) + if messages: + # Advance last_id to avoid re-reading — broker.read returns decoded dicts, + # so we track progress by re-reading with "0" for replaying or "$" for new only. + # Since we block on "$" we get only new messages each iteration. + pass + finally: + await broker.close() + await db.close() + await exchange.close() + + +def main() -> None: + asyncio.run(run()) + + +if __name__ == "__main__": + main() diff --git a/services/order-executor/src/order_executor/risk_manager.py b/services/order-executor/src/order_executor/risk_manager.py new file mode 100644 index 0000000..8e8a72c --- /dev/null +++ b/services/order-executor/src/order_executor/risk_manager.py @@ -0,0 +1,55 @@ +"""Risk management for order execution.""" +from dataclasses import dataclass +from decimal import Decimal + +from shared.models import Signal, OrderSide, Position + + +@dataclass +class RiskCheckResult: + allowed: bool + reason: str + + +class RiskManager: + """Evaluates risk before order execution.""" + + def __init__( + self, + max_position_size: Decimal, + stop_loss_pct: Decimal, + daily_loss_limit_pct: Decimal, + ) -> None: + self.max_position_size = max_position_size + self.stop_loss_pct = stop_loss_pct + self.daily_loss_limit_pct = daily_loss_limit_pct + + def check( + self, + signal: Signal, + balance: Decimal, + positions: dict[str, Position], + daily_pnl: Decimal, + ) -> RiskCheckResult: + """Run risk checks against a signal and current portfolio state.""" + # Check daily loss limit + if balance > 0 and (daily_pnl / balance) * 100 < -self.daily_loss_limit_pct: + return RiskCheckResult(allowed=False, reason="Daily loss limit exceeded") + + if signal.side == OrderSide.BUY: + order_cost = signal.price * signal.quantity + + # Check sufficient balance + if order_cost > balance: + return RiskCheckResult(allowed=False, reason="Insufficient balance") + + # Check position size limit + position = positions.get(signal.symbol) + current_position_value = Decimal(0) + if position is not None: + current_position_value = position.quantity * position.current_price + + if balance > 0 and (current_position_value + order_cost) / balance > self.max_position_size: + return RiskCheckResult(allowed=False, reason="Position size exceeded") + + return RiskCheckResult(allowed=True, reason="OK") diff --git a/services/order-executor/tests/__init__.py b/services/order-executor/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/order-executor/tests/__init__.py diff --git a/services/order-executor/tests/test_executor.py b/services/order-executor/tests/test_executor.py new file mode 100644 index 0000000..5b18992 --- /dev/null +++ b/services/order-executor/tests/test_executor.py @@ -0,0 +1,122 @@ +"""Tests for OrderExecutor.""" +from decimal import Decimal +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from shared.models import OrderSide, OrderStatus, Signal +from order_executor.executor import OrderExecutor +from order_executor.risk_manager import RiskCheckResult, RiskManager + + +def make_signal(side: OrderSide = OrderSide.BUY, price: str = "100", quantity: str = "1") -> Signal: + return Signal( + strategy="test", + symbol="BTC/USDT", + side=side, + price=Decimal(price), + quantity=Decimal(quantity), + reason="test", + ) + + +def make_mock_exchange(free_usdt: float = 10000.0) -> AsyncMock: + exchange = AsyncMock() + exchange.fetch_balance.return_value = {"free": {"USDT": free_usdt}} + exchange.create_order = AsyncMock(return_value={"id": "exchange-order-123"}) + return exchange + + +def make_mock_risk_manager(allowed: bool = True, reason: str = "OK") -> MagicMock: + rm = MagicMock(spec=RiskManager) + rm.check.return_value = RiskCheckResult(allowed=allowed, reason=reason) + return rm + + +def make_mock_broker() -> AsyncMock: + broker = AsyncMock() + broker.publish = AsyncMock() + return broker + + +def make_mock_db() -> AsyncMock: + db = AsyncMock() + db.insert_order = AsyncMock() + return db + + +@pytest.mark.asyncio +async def test_executor_places_order_when_risk_passes(): + """When risk check passes, create_order is called and order status is FILLED.""" + exchange = make_mock_exchange() + risk_manager = make_mock_risk_manager(allowed=True) + broker = make_mock_broker() + db = make_mock_db() + + executor = OrderExecutor( + exchange=exchange, + risk_manager=risk_manager, + broker=broker, + db=db, + dry_run=False, + ) + + signal = make_signal() + order = await executor.execute(signal) + + assert order is not None + assert order.status == OrderStatus.FILLED + exchange.create_order.assert_called_once() + db.insert_order.assert_called_once_with(order) + broker.publish.assert_called_once() + + +@pytest.mark.asyncio +async def test_executor_rejects_when_risk_fails(): + """When risk check fails, create_order is not called and None is returned.""" + exchange = make_mock_exchange() + risk_manager = make_mock_risk_manager(allowed=False, reason="Position size exceeded") + broker = make_mock_broker() + db = make_mock_db() + + executor = OrderExecutor( + exchange=exchange, + risk_manager=risk_manager, + broker=broker, + db=db, + dry_run=False, + ) + + signal = make_signal() + order = await executor.execute(signal) + + assert order is None + exchange.create_order.assert_not_called() + db.insert_order.assert_not_called() + broker.publish.assert_not_called() + + +@pytest.mark.asyncio +async def test_executor_dry_run_does_not_call_exchange(): + """In dry-run mode, risk passes, order is FILLED, but exchange.create_order is NOT called.""" + exchange = make_mock_exchange() + risk_manager = make_mock_risk_manager(allowed=True) + broker = make_mock_broker() + db = make_mock_db() + + executor = OrderExecutor( + exchange=exchange, + risk_manager=risk_manager, + broker=broker, + db=db, + dry_run=True, + ) + + signal = make_signal() + order = await executor.execute(signal) + + assert order is not None + assert order.status == OrderStatus.FILLED + exchange.create_order.assert_not_called() + db.insert_order.assert_called_once_with(order) + broker.publish.assert_called_once() diff --git a/services/order-executor/tests/test_risk_manager.py b/services/order-executor/tests/test_risk_manager.py new file mode 100644 index 0000000..f6b5545 --- /dev/null +++ b/services/order-executor/tests/test_risk_manager.py @@ -0,0 +1,72 @@ +"""Tests for RiskManager.""" +from decimal import Decimal + +import pytest + +from shared.models import OrderSide, Position, Signal +from order_executor.risk_manager import RiskManager + + +def make_signal(side: OrderSide, price: str, quantity: str, symbol: str = "BTC/USDT") -> Signal: + return Signal( + strategy="test", + symbol=symbol, + side=side, + price=Decimal(price), + quantity=Decimal(quantity), + reason="test signal", + ) + + +def make_risk_manager( + max_position_size: str = "0.1", + stop_loss_pct: str = "5.0", + daily_loss_limit_pct: str = "10.0", +) -> RiskManager: + return RiskManager( + max_position_size=Decimal(max_position_size), + stop_loss_pct=Decimal(stop_loss_pct), + daily_loss_limit_pct=Decimal(daily_loss_limit_pct), + ) + + +def test_risk_check_passes_normal_order(): + """Small BUY order with enough balance should be allowed.""" + rm = make_risk_manager() + signal = make_signal(side=OrderSide.BUY, price="100", quantity="0.5") + # cost = 50, balance = 10000, position_value = 0 => (0+50)/10000 = 0.5% < 10% + result = rm.check(signal, balance=Decimal("10000"), positions={}, daily_pnl=Decimal("0")) + assert result.allowed is True + assert result.reason == "OK" + + +def test_risk_check_rejects_exceeding_position_size(): + """5 BTC at $50,000 = $250,000 order cost on $10,000,000 balance exceeds 10% limit.""" + rm = make_risk_manager(max_position_size="0.1") + signal = make_signal(side=OrderSide.BUY, price="50000", quantity="5") + # cost = 250000, balance = 1000000 => 250000/1000000 = 25% > 10% + # balance is sufficient (250000 < 1000000) but position size is exceeded + result = rm.check(signal, balance=Decimal("1000000"), positions={}, daily_pnl=Decimal("0")) + assert result.allowed is False + assert result.reason == "Position size exceeded" + + +def test_risk_check_rejects_daily_loss_exceeded(): + """Daily PnL of -1100 on 10000 balance = -11%, exceeding -10% limit.""" + rm = make_risk_manager(daily_loss_limit_pct="10.0") + signal = make_signal(side=OrderSide.BUY, price="100", quantity="0.1") + result = rm.check( + signal, balance=Decimal("10000"), positions={}, daily_pnl=Decimal("-1100") + ) + assert result.allowed is False + assert result.reason == "Daily loss limit exceeded" + + +def test_risk_check_rejects_insufficient_balance(): + """Order cost of 500 exceeds available balance of 100.""" + rm = make_risk_manager() + signal = make_signal(side=OrderSide.BUY, price="100", quantity="5") + # cost = 500, balance = 100 + result = rm.check(signal, balance=Decimal("100"), positions={}, daily_pnl=Decimal("0")) + assert result.allowed is False + assert result.reason == "Insufficient balance" diff --git a/services/portfolio-manager/Dockerfile b/services/portfolio-manager/Dockerfile new file mode 100644 index 0000000..3f8587e --- /dev/null +++ b/services/portfolio-manager/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.12-slim +WORKDIR /app +COPY shared/ shared/ +RUN pip install --no-cache-dir ./shared +COPY services/portfolio-manager/ services/portfolio-manager/ +RUN pip install --no-cache-dir ./services/portfolio-manager +CMD ["python", "-m", "portfolio_manager.main"] diff --git a/services/portfolio-manager/pyproject.toml b/services/portfolio-manager/pyproject.toml new file mode 100644 index 0000000..8245aa0 --- /dev/null +++ b/services/portfolio-manager/pyproject.toml @@ -0,0 +1,16 @@ +[project] +name = "portfolio-manager" +version = "0.1.0" +description = "Portfolio tracking and PnL calculation service" +requires-python = ">=3.12" +dependencies = ["trading-shared"] + +[project.optional-dependencies] +dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/portfolio_manager"] diff --git a/services/portfolio-manager/src/portfolio_manager/__init__.py b/services/portfolio-manager/src/portfolio_manager/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/portfolio-manager/src/portfolio_manager/__init__.py diff --git a/services/portfolio-manager/src/portfolio_manager/config.py b/services/portfolio-manager/src/portfolio_manager/config.py new file mode 100644 index 0000000..bbd5049 --- /dev/null +++ b/services/portfolio-manager/src/portfolio_manager/config.py @@ -0,0 +1,6 @@ +"""Portfolio Manager configuration.""" +from shared.config import Settings + + +class PortfolioConfig(Settings): + snapshot_interval_hours: int = 24 diff --git a/services/portfolio-manager/src/portfolio_manager/main.py b/services/portfolio-manager/src/portfolio_manager/main.py new file mode 100644 index 0000000..cb7e6a8 --- /dev/null +++ b/services/portfolio-manager/src/portfolio_manager/main.py @@ -0,0 +1,56 @@ +"""Portfolio Manager Service entry point.""" +import asyncio +import logging + +from shared.broker import RedisBroker +from shared.events import Event, OrderEvent + +from portfolio_manager.config import PortfolioConfig +from portfolio_manager.portfolio import PortfolioTracker + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +ORDERS_STREAM = "orders" + + +async def run() -> None: + config = PortfolioConfig() + broker = RedisBroker(config.redis_url) + tracker = PortfolioTracker() + + last_id = "$" + logger.info("Portfolio manager started, listening on stream=%s", ORDERS_STREAM) + + try: + while True: + messages = await broker.read(ORDERS_STREAM, last_id=last_id, block=1000) + for msg in messages: + try: + event = Event.from_dict(msg) + if isinstance(event, OrderEvent): + order = event.data + tracker.apply_order(order) + logger.info( + "Applied order symbol=%s side=%s qty=%s price=%s", + order.symbol, + order.side, + order.quantity, + order.price, + ) + positions = tracker.get_all_positions() + logger.info("Current positions count=%d", len(positions)) + except Exception: + logger.exception("Failed to process message: %s", msg) + # Update last_id to the latest processed message id if broker returns ids + # Since broker.read returns parsed payloads (not ids), we use "$" to get new msgs + finally: + await broker.close() + + +def main() -> None: + asyncio.run(run()) + + +if __name__ == "__main__": + main() diff --git a/services/portfolio-manager/src/portfolio_manager/pnl.py b/services/portfolio-manager/src/portfolio_manager/pnl.py new file mode 100644 index 0000000..96f0da8 --- /dev/null +++ b/services/portfolio-manager/src/portfolio_manager/pnl.py @@ -0,0 +1,21 @@ +"""PnL calculation functions for the portfolio manager.""" +from decimal import Decimal + + +def calculate_unrealized_pnl( + quantity: Decimal, + avg_entry_price: Decimal, + current_price: Decimal, +) -> Decimal: + """Calculate unrealized PnL for an open position.""" + return quantity * (current_price - avg_entry_price) + + +def calculate_realized_pnl( + buy_price: Decimal, + sell_price: Decimal, + quantity: Decimal, + fee: Decimal = Decimal("0"), +) -> Decimal: + """Calculate realized PnL for a completed trade.""" + return quantity * (sell_price - buy_price) - fee diff --git a/services/portfolio-manager/src/portfolio_manager/portfolio.py b/services/portfolio-manager/src/portfolio_manager/portfolio.py new file mode 100644 index 0000000..59106bb --- /dev/null +++ b/services/portfolio-manager/src/portfolio_manager/portfolio.py @@ -0,0 +1,62 @@ +"""Portfolio tracking for the portfolio manager service.""" +from decimal import Decimal + +from shared.models import Order, OrderSide, Position + + +class _PositionState: + """Internal state for tracking a single symbol's position.""" + + def __init__(self) -> None: + self.quantity: Decimal = Decimal("0") + self.avg_entry: Decimal = Decimal("0") + + +class PortfolioTracker: + """Tracks positions and updates them based on filled orders.""" + + def __init__(self) -> None: + self._positions: dict[str, _PositionState] = {} + + def _get_or_create(self, symbol: str) -> _PositionState: + if symbol not in self._positions: + self._positions[symbol] = _PositionState() + return self._positions[symbol] + + def apply_order(self, order: Order) -> None: + """Update internal position state based on a filled order.""" + state = self._get_or_create(order.symbol) + + if order.side == OrderSide.BUY: + # Weighted average entry price + total_cost = state.avg_entry * state.quantity + order.price * order.quantity + state.quantity += order.quantity + if state.quantity > Decimal("0"): + state.avg_entry = total_cost / state.quantity + elif order.side == OrderSide.SELL: + state.quantity -= order.quantity + # Keep avg_entry unchanged unless fully sold + if state.quantity <= Decimal("0"): + state.quantity = Decimal("0") + state.avg_entry = Decimal("0") + + def get_position(self, symbol: str) -> Position | None: + """Return a Position model for the symbol, or None if no/zero position.""" + state = self._positions.get(symbol) + if state is None or state.quantity <= Decimal("0"): + return None + return Position( + symbol=symbol, + quantity=state.quantity, + avg_entry_price=state.avg_entry, + current_price=state.avg_entry, # No live price here; caller can update + ) + + def get_all_positions(self) -> list[Position]: + """Return all non-zero positions.""" + positions = [] + for symbol in self._positions: + pos = self.get_position(symbol) + if pos is not None: + positions.append(pos) + return positions diff --git a/services/portfolio-manager/tests/__init__.py b/services/portfolio-manager/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/portfolio-manager/tests/__init__.py diff --git a/services/portfolio-manager/tests/test_pnl.py b/services/portfolio-manager/tests/test_pnl.py new file mode 100644 index 0000000..4462adc --- /dev/null +++ b/services/portfolio-manager/tests/test_pnl.py @@ -0,0 +1,32 @@ +"""Tests for PnL calculation functions.""" +from decimal import Decimal + +from portfolio_manager.pnl import calculate_realized_pnl, calculate_unrealized_pnl + + +def test_unrealized_pnl_profit() -> None: + result = calculate_unrealized_pnl( + quantity=Decimal("0.1"), + avg_entry_price=Decimal("50000"), + current_price=Decimal("55000"), + ) + assert result == Decimal("500") + + +def test_unrealized_pnl_loss() -> None: + result = calculate_unrealized_pnl( + quantity=Decimal("0.1"), + avg_entry_price=Decimal("50000"), + current_price=Decimal("45000"), + ) + assert result == Decimal("-500") + + +def test_realized_pnl_single_trade() -> None: + result = calculate_realized_pnl( + buy_price=Decimal("50000"), + sell_price=Decimal("55000"), + quantity=Decimal("0.1"), + fee=Decimal("5.5"), + ) + assert result == Decimal("494.5") diff --git a/services/portfolio-manager/tests/test_portfolio.py b/services/portfolio-manager/tests/test_portfolio.py new file mode 100644 index 0000000..26319ca --- /dev/null +++ b/services/portfolio-manager/tests/test_portfolio.py @@ -0,0 +1,57 @@ +"""Tests for PortfolioTracker.""" +from decimal import Decimal + +from shared.models import Order, OrderSide, OrderStatus, OrderType +from portfolio_manager.portfolio import PortfolioTracker + + +def make_order(side: OrderSide, price: str, quantity: str) -> Order: + """Helper to create a filled Order.""" + return Order( + signal_id="test-signal", + symbol="BTC/USDT", + side=side, + type=OrderType.MARKET, + price=Decimal(price), + quantity=Decimal(quantity), + status=OrderStatus.FILLED, + ) + + +def test_portfolio_add_buy_order() -> None: + tracker = PortfolioTracker() + order = make_order(OrderSide.BUY, "50000", "0.1") + tracker.apply_order(order) + + position = tracker.get_position("BTC/USDT") + assert position is not None + assert position.quantity == Decimal("0.1") + assert position.avg_entry_price == Decimal("50000") + + +def test_portfolio_add_multiple_buys() -> None: + tracker = PortfolioTracker() + tracker.apply_order(make_order(OrderSide.BUY, "50000", "0.1")) + tracker.apply_order(make_order(OrderSide.BUY, "52000", "0.1")) + + position = tracker.get_position("BTC/USDT") + assert position is not None + assert position.quantity == Decimal("0.2") + assert position.avg_entry_price == Decimal("51000") + + +def test_portfolio_sell_reduces_position() -> None: + tracker = PortfolioTracker() + tracker.apply_order(make_order(OrderSide.BUY, "50000", "0.2")) + tracker.apply_order(make_order(OrderSide.SELL, "55000", "0.1")) + + position = tracker.get_position("BTC/USDT") + assert position is not None + assert position.quantity == Decimal("0.1") + assert position.avg_entry_price == Decimal("50000") + + +def test_portfolio_no_position_returns_none() -> None: + tracker = PortfolioTracker() + position = tracker.get_position("ETH/USDT") + assert position is None diff --git a/services/strategy-engine/Dockerfile b/services/strategy-engine/Dockerfile new file mode 100644 index 0000000..adecdd4 --- /dev/null +++ b/services/strategy-engine/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.12-slim +WORKDIR /app +COPY shared/ shared/ +RUN pip install --no-cache-dir ./shared +COPY services/strategy-engine/ services/strategy-engine/ +RUN pip install --no-cache-dir ./services/strategy-engine +CMD ["python", "-m", "strategy_engine.main"] diff --git a/services/strategy-engine/pyproject.toml b/services/strategy-engine/pyproject.toml new file mode 100644 index 0000000..a86b282 --- /dev/null +++ b/services/strategy-engine/pyproject.toml @@ -0,0 +1,19 @@ +[project] +name = "strategy-engine" +version = "0.1.0" +description = "Plugin-based strategy execution engine" +requires-python = ">=3.12" +dependencies = [ + "pandas>=2.0", + "trading-shared", +] + +[project.optional-dependencies] +dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/strategy_engine"] diff --git a/services/strategy-engine/src/strategy_engine/__init__.py b/services/strategy-engine/src/strategy_engine/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/strategy-engine/src/strategy_engine/__init__.py diff --git a/services/strategy-engine/src/strategy_engine/config.py b/services/strategy-engine/src/strategy_engine/config.py new file mode 100644 index 0000000..2864b09 --- /dev/null +++ b/services/strategy-engine/src/strategy_engine/config.py @@ -0,0 +1,8 @@ +"""Strategy Engine configuration.""" +from shared.config import Settings + + +class StrategyConfig(Settings): + symbols: list[str] = ["BTC/USDT"] + timeframes: list[str] = ["1m"] + strategy_params: dict = {} diff --git a/services/strategy-engine/src/strategy_engine/engine.py b/services/strategy-engine/src/strategy_engine/engine.py new file mode 100644 index 0000000..09dbf65 --- /dev/null +++ b/services/strategy-engine/src/strategy_engine/engine.py @@ -0,0 +1,54 @@ +"""Strategy Engine: consumes candle events and publishes signals.""" +import logging + +from shared.broker import RedisBroker +from shared.events import CandleEvent, SignalEvent, Event + +from strategies.base import BaseStrategy + +logger = logging.getLogger(__name__) + + +class StrategyEngine: + def __init__(self, broker: RedisBroker, strategies: list[BaseStrategy]) -> None: + self._broker = broker + self._strategies = strategies + + async def process_once(self, stream: str, last_id: str) -> str: + """Read one batch of messages from the stream, process candles, publish signals. + + Returns the updated last_id for the next call. + """ + messages = await self._broker.read(stream, last_id=last_id, count=10, block=100) + + for raw in messages: + try: + event = Event.from_dict(raw) + except Exception as exc: + logger.warning("Failed to parse event: %s – %s", raw, exc) + continue + + if not isinstance(event, CandleEvent): + continue + + candle = event.data + for strategy in self._strategies: + try: + signal = strategy.on_candle(candle) + except Exception as exc: + logger.error( + "Strategy %s raised on candle: %s", strategy.name, exc + ) + continue + + if signal is not None: + signal_event = SignalEvent(data=signal) + await self._broker.publish("signals", signal_event.to_dict()) + logger.info( + "Signal published: strategy=%s symbol=%s side=%s", + signal.strategy, + signal.symbol, + signal.side, + ) + + return last_id diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py new file mode 100644 index 0000000..83bb867 --- /dev/null +++ b/services/strategy-engine/src/strategy_engine/main.py @@ -0,0 +1,56 @@ +"""Strategy Engine Service entry point.""" +import asyncio +import logging +from pathlib import Path + +from shared.broker import RedisBroker + +from strategy_engine.config import StrategyConfig +from strategy_engine.engine import StrategyEngine +from strategy_engine.plugin_loader import load_strategies + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# The strategies directory lives alongside the installed package +STRATEGIES_DIR = Path(__file__).parent.parent.parent.parent / "strategies" + + +async def run() -> None: + config = StrategyConfig() + broker = RedisBroker(config.redis_url) + + strategies_dir = STRATEGIES_DIR + strategies = load_strategies(strategies_dir) + + # Configure each strategy with params from config + for strategy in strategies: + params = config.strategy_params.get(strategy.name, {}) + strategy.configure(params) + + logger.info( + "Loaded %d strategies: %s", + len(strategies), + [s.name for s in strategies], + ) + + engine = StrategyEngine(broker=broker, strategies=strategies) + + try: + for symbol in config.symbols: + stream = f"candles.{symbol.replace('/', '_')}" + last_id = "$" + logger.info("Starting engine loop for stream=%s", stream) + + while True: + last_id = await engine.process_once(stream, last_id) + finally: + await broker.close() + + +def main() -> None: + asyncio.run(run()) + + +if __name__ == "__main__": + main() diff --git a/services/strategy-engine/src/strategy_engine/plugin_loader.py b/services/strategy-engine/src/strategy_engine/plugin_loader.py new file mode 100644 index 0000000..719dc6d --- /dev/null +++ b/services/strategy-engine/src/strategy_engine/plugin_loader.py @@ -0,0 +1,36 @@ +"""Dynamic plugin loader for strategy modules.""" +import importlib.util +import sys +from pathlib import Path + +from strategies.base import BaseStrategy + + +def load_strategies(strategies_dir: Path) -> list[BaseStrategy]: + """Scan strategies_dir for *.py files and load all BaseStrategy subclasses.""" + loaded: list[BaseStrategy] = [] + + for path in sorted(strategies_dir.glob("*.py")): + # Skip dunder files and base + if path.name.startswith("__") or path.name == "base.py": + continue + + module_name = f"_strategy_plugin_{path.stem}" + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is None or spec.loader is None: + continue + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + for attr_name in dir(module): + obj = getattr(module, attr_name) + if ( + isinstance(obj, type) + and issubclass(obj, BaseStrategy) + and obj is not BaseStrategy + ): + loaded.append(obj()) + + return loaded diff --git a/services/strategy-engine/strategies/__init__.py b/services/strategy-engine/strategies/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/strategy-engine/strategies/__init__.py diff --git a/services/strategy-engine/strategies/base.py b/services/strategy-engine/strategies/base.py new file mode 100644 index 0000000..06101d0 --- /dev/null +++ b/services/strategy-engine/strategies/base.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +from shared.models import Candle, Signal + + +class BaseStrategy(ABC): + name: str = "base" + + @abstractmethod + def on_candle(self, candle: Candle) -> Signal | None: + pass + + @abstractmethod + def configure(self, params: dict) -> None: + pass + + def reset(self) -> None: + pass diff --git a/services/strategy-engine/strategies/grid_strategy.py b/services/strategy-engine/strategies/grid_strategy.py new file mode 100644 index 0000000..f669f09 --- /dev/null +++ b/services/strategy-engine/strategies/grid_strategy.py @@ -0,0 +1,77 @@ +from decimal import Decimal +from typing import Optional + +import numpy as np + +from shared.models import Candle, Signal, OrderSide +from strategies.base import BaseStrategy + + +class GridStrategy(BaseStrategy): + name: str = "grid" + + def __init__(self) -> None: + self._lower_price: float = 0.0 + self._upper_price: float = 0.0 + self._grid_count: int = 5 + self._quantity: Decimal = Decimal("0.01") + self._grid_levels: list[float] = [] + self._last_zone: Optional[int] = None + + def configure(self, params: dict) -> None: + self._lower_price = float(params["lower_price"]) + self._upper_price = float(params["upper_price"]) + self._grid_count = int(params.get("grid_count", 5)) + self._quantity = Decimal(str(params.get("quantity", "0.01"))) + self._grid_levels = list( + np.linspace(self._lower_price, self._upper_price, self._grid_count + 1) + ) + self._last_zone = None + + def reset(self) -> None: + self._last_zone = None + + def _get_zone(self, price: float) -> int: + """Return the grid zone index for a given price. + + Zone 0 is below the lowest level, zone grid_count is above the highest level. + Zones 1..grid_count-1 are between levels. + """ + for i, level in enumerate(self._grid_levels): + if price < level: + return i + return len(self._grid_levels) + + def on_candle(self, candle: Candle) -> Signal | None: + price = float(candle.close) + current_zone = self._get_zone(price) + + if self._last_zone is None: + self._last_zone = current_zone + return None + + prev_zone = self._last_zone + self._last_zone = current_zone + + if current_zone < prev_zone: + # Price moved to a lower zone → BUY + return Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.BUY, + price=candle.close, + quantity=self._quantity, + reason=f"Grid: price crossed down from zone {prev_zone} to {current_zone}", + ) + elif current_zone > prev_zone: + # Price moved to a higher zone → SELL + return Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.SELL, + price=candle.close, + quantity=self._quantity, + reason=f"Grid: price crossed up from zone {prev_zone} to {current_zone}", + ) + + return None diff --git a/services/strategy-engine/strategies/rsi_strategy.py b/services/strategy-engine/strategies/rsi_strategy.py new file mode 100644 index 0000000..aebbafc --- /dev/null +++ b/services/strategy-engine/strategies/rsi_strategy.py @@ -0,0 +1,77 @@ +from collections import deque +from decimal import Decimal + +import pandas as pd + +from shared.models import Candle, Signal, OrderSide +from strategies.base import BaseStrategy + + +def _compute_rsi(series: pd.Series, period: int) -> float | None: + """Compute RSI using Wilder's smoothing (EMA-based).""" + if len(series) < period + 1: + return None + delta = series.diff() + gain = delta.clip(lower=0) + loss = -delta.clip(upper=0) + avg_gain = gain.ewm(com=period - 1, min_periods=period).mean() + avg_loss = loss.ewm(com=period - 1, min_periods=period).mean() + rs = avg_gain / avg_loss.replace(0, float("nan")) + rsi = 100 - (100 / (1 + rs)) + value = rsi.iloc[-1] + if pd.isna(value): + return None + return float(value) + + +class RsiStrategy(BaseStrategy): + name: str = "rsi" + + def __init__(self) -> None: + self._closes: deque[float] = deque(maxlen=200) + self._period: int = 14 + self._oversold: float = 30.0 + self._overbought: float = 70.0 + self._quantity: Decimal = Decimal("0.01") + + def configure(self, params: dict) -> None: + self._period = int(params.get("period", 14)) + self._oversold = float(params.get("oversold", 30)) + self._overbought = float(params.get("overbought", 70)) + self._quantity = Decimal(str(params.get("quantity", "0.01"))) + + def reset(self) -> None: + self._closes.clear() + + def on_candle(self, candle: Candle) -> Signal | None: + self._closes.append(float(candle.close)) + + if len(self._closes) < self._period + 1: + return None + + series = pd.Series(list(self._closes)) + rsi_value = _compute_rsi(series, self._period) + + if rsi_value is None: + return None + + if rsi_value < self._oversold: + return Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.BUY, + price=candle.close, + quantity=self._quantity, + reason=f"RSI {rsi_value:.2f} below oversold threshold {self._oversold}", + ) + elif rsi_value > self._overbought: + return Signal( + strategy=self.name, + symbol=candle.symbol, + side=OrderSide.SELL, + price=candle.close, + quantity=self._quantity, + reason=f"RSI {rsi_value:.2f} above overbought threshold {self._overbought}", + ) + + return None diff --git a/services/strategy-engine/tests/__init__.py b/services/strategy-engine/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/strategy-engine/tests/__init__.py diff --git a/services/strategy-engine/tests/conftest.py b/services/strategy-engine/tests/conftest.py new file mode 100644 index 0000000..c9ef308 --- /dev/null +++ b/services/strategy-engine/tests/conftest.py @@ -0,0 +1,8 @@ +"""Pytest configuration: ensure strategies/ is importable.""" +import sys +from pathlib import Path + +# Add the strategies directory to sys.path so that `from strategies.base import ...` works +STRATEGIES_DIR = Path(__file__).parent.parent / "strategies" +if str(STRATEGIES_DIR) not in sys.path: + sys.path.insert(0, str(STRATEGIES_DIR.parent)) diff --git a/services/strategy-engine/tests/test_engine.py b/services/strategy-engine/tests/test_engine.py new file mode 100644 index 0000000..33ad4dd --- /dev/null +++ b/services/strategy-engine/tests/test_engine.py @@ -0,0 +1,72 @@ +"""Tests for the StrategyEngine.""" +from datetime import datetime, timezone +from decimal import Decimal +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from shared.models import Candle, Signal, OrderSide +from shared.events import CandleEvent, SignalEvent +from strategy_engine.engine import StrategyEngine + + +def make_candle_event() -> dict: + candle = Candle( + symbol="BTC/USDT", + timeframe="1m", + open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open=Decimal("50000"), + high=Decimal("50100"), + low=Decimal("49900"), + close=Decimal("50050"), + volume=Decimal("10.0"), + ) + return CandleEvent(data=candle).to_dict() + + +def make_signal() -> Signal: + return Signal( + strategy="test", + symbol="BTC/USDT", + side=OrderSide.BUY, + price=Decimal("50050"), + quantity=Decimal("0.01"), + reason="test signal", + ) + + +@pytest.mark.asyncio +async def test_engine_dispatches_candle_to_strategies(): + broker = MagicMock() + broker.read = AsyncMock(return_value=[make_candle_event()]) + broker.publish = AsyncMock() + + strategy = MagicMock() + strategy.on_candle = MagicMock(return_value=None) + + engine = StrategyEngine(broker=broker, strategies=[strategy]) + await engine.process_once("candles.BTC_USDT", "0") + + strategy.on_candle.assert_called_once() + candle_arg = strategy.on_candle.call_args[0][0] + assert isinstance(candle_arg, Candle) + assert candle_arg.symbol == "BTC/USDT" + + +@pytest.mark.asyncio +async def test_engine_publishes_signal_when_strategy_returns_one(): + broker = MagicMock() + broker.read = AsyncMock(return_value=[make_candle_event()]) + broker.publish = AsyncMock() + + strategy = MagicMock() + strategy.on_candle = MagicMock(return_value=make_signal()) + + engine = StrategyEngine(broker=broker, strategies=[strategy]) + await engine.process_once("candles.BTC_USDT", "0") + + broker.publish.assert_called_once() + call_args = broker.publish.call_args + assert call_args[0][0] == "signals" + published_data = call_args[0][1] + assert published_data["type"] == "SIGNAL" diff --git a/services/strategy-engine/tests/test_grid_strategy.py b/services/strategy-engine/tests/test_grid_strategy.py new file mode 100644 index 0000000..d96ebba --- /dev/null +++ b/services/strategy-engine/tests/test_grid_strategy.py @@ -0,0 +1,60 @@ +"""Tests for the Grid strategy.""" +from datetime import datetime, timezone +from decimal import Decimal + +import pytest + +from shared.models import Candle, OrderSide +from strategies.grid_strategy import GridStrategy + + +def make_candle(close: float) -> Candle: + return Candle( + symbol="BTC/USDT", + timeframe="1m", + open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open=Decimal(str(close)), + high=Decimal(str(close)), + low=Decimal(str(close)), + close=Decimal(str(close)), + volume=Decimal("1.0"), + ) + + +def _configured_strategy() -> GridStrategy: + strategy = GridStrategy() + strategy.configure({ + "lower_price": 48000, + "upper_price": 52000, + "grid_count": 5, + "quantity": "0.01", + }) + return strategy + + +def test_grid_strategy_buy_at_lower_grid(): + strategy = _configured_strategy() + # First candle: establish zone at upper area + strategy.on_candle(make_candle(51500)) + # Second candle: price drops to lower zone → BUY + signal = strategy.on_candle(make_candle(48100)) + assert signal is not None + assert signal.side == OrderSide.BUY + + +def test_grid_strategy_sell_at_upper_grid(): + strategy = _configured_strategy() + # First candle: establish zone at lower area + strategy.on_candle(make_candle(48100)) + # Second candle: price rises to upper zone → SELL + signal = strategy.on_candle(make_candle(51900)) + assert signal is not None + assert signal.side == OrderSide.SELL + + +def test_grid_strategy_no_signal_in_same_zone(): + strategy = _configured_strategy() + # Both candles in approximately the same zone + strategy.on_candle(make_candle(50000)) + signal = strategy.on_candle(make_candle(50100)) + assert signal is None diff --git a/services/strategy-engine/tests/test_plugin_loader.py b/services/strategy-engine/tests/test_plugin_loader.py new file mode 100644 index 0000000..9496bab --- /dev/null +++ b/services/strategy-engine/tests/test_plugin_loader.py @@ -0,0 +1,22 @@ +"""Tests for the plugin loader.""" +from pathlib import Path + +import pytest + +from strategy_engine.plugin_loader import load_strategies + + +STRATEGIES_DIR = Path(__file__).parent.parent / "strategies" + + +def test_load_strategies_finds_rsi_and_grid(): + strategies = load_strategies(STRATEGIES_DIR) + names = [s.name for s in strategies] + assert "rsi" in names + assert "grid" in names + + +def test_load_strategies_skips_base(): + strategies = load_strategies(STRATEGIES_DIR) + names = [s.name for s in strategies] + assert "base" not in names diff --git a/services/strategy-engine/tests/test_rsi_strategy.py b/services/strategy-engine/tests/test_rsi_strategy.py new file mode 100644 index 0000000..90fface --- /dev/null +++ b/services/strategy-engine/tests/test_rsi_strategy.py @@ -0,0 +1,45 @@ +"""Tests for the RSI strategy.""" +from datetime import datetime, timezone +from decimal import Decimal + +import pytest + +from shared.models import Candle, OrderSide +from strategies.rsi_strategy import RsiStrategy + + +def make_candle(close: float, idx: int = 0) -> Candle: + return Candle( + symbol="BTC/USDT", + timeframe="1m", + open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), + open=Decimal(str(close)), + high=Decimal(str(close)), + low=Decimal(str(close)), + close=Decimal(str(close)), + volume=Decimal("1.0"), + ) + + +def test_rsi_strategy_no_signal_insufficient_data(): + strategy = RsiStrategy() + strategy.configure({}) + candle = make_candle(50000.0) + result = strategy.on_candle(candle) + assert result is None + + +def test_rsi_strategy_buy_signal_on_oversold(): + strategy = RsiStrategy() + strategy.configure({"period": 14, "oversold": 30, "overbought": 70}) + + # Feed 20 steadily declining prices to force RSI into oversold territory + prices = [50000 - i * 500 for i in range(20)] + signal = None + for i, price in enumerate(prices): + signal = strategy.on_candle(make_candle(price, i)) + + # We may or may not get a signal depending on RSI calculation; + # if a signal is returned, it must be a BUY + if signal is not None: + assert signal.side == OrderSide.BUY |
