diff options
| author | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 18:23:46 +0900 |
|---|---|---|
| committer | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 18:23:46 +0900 |
| commit | 9e82c51dfde3941189db1b2d62dcc239442b9dc6 (patch) | |
| tree | 2c19068a4b2e381d44d22d662ec095794e144180 | |
| parent | cd3a06c7788ad8a747b1b4579fb6c45b6c43008e (diff) | |
feat(backtester): add walk-forward analysis engine
| -rw-r--r-- | services/backtester/src/backtester/walk_forward.py | 145 | ||||
| -rw-r--r-- | services/backtester/tests/test_walk_forward.py | 103 |
2 files changed, 248 insertions, 0 deletions
# ---------------------------------------------------------------------------
# services/backtester/src/backtester/walk_forward.py
# ---------------------------------------------------------------------------
"""Walk-forward analysis for strategy parameter optimization."""
from dataclasses import dataclass, field
from decimal import Decimal
from typing import Callable

from shared.models import Candle
from backtester.engine import BacktestEngine, BacktestResult, StrategyProtocol


@dataclass
class WalkForwardWindow:
    """Result for a single in-sample/out-of-sample window."""

    # Index of the window within the candle series. Windows that are skipped
    # for being too short leave gaps in the sequence of indices.
    window_index: int
    # Backtest of ``best_params`` on the in-sample slice.
    in_sample_result: BacktestResult
    # Backtest of ``best_params`` on the held-out slice.
    out_of_sample_result: BacktestResult
    # Entry of the parameter grid with the highest in-sample profit.
    best_params: dict


@dataclass
class WalkForwardResult:
    """Aggregated walk-forward analysis results."""

    strategy_name: str
    symbol: str
    num_windows: int
    windows: list[WalkForwardWindow] = field(default_factory=list)

    @staticmethod
    def _profit_pct(results: list[BacktestResult]) -> float:
        """Return summed profit as a percentage of the first initial balance.

        Returns 0.0 when there are no results or when the first result's
        initial balance is not positive.
        """
        if not results:
            return 0.0
        total_profit = sum(float(r.profit) for r in results)
        initial = float(results[0].initial_balance)
        return (total_profit / initial * 100) if initial > 0 else 0.0

    @property
    def out_of_sample_profit_pct(self) -> float:
        """Combined out-of-sample profit percentage."""
        return self._profit_pct([w.out_of_sample_result for w in self.windows])

    @property
    def in_sample_profit_pct(self) -> float:
        """Combined in-sample profit percentage."""
        return self._profit_pct([w.in_sample_result for w in self.windows])

    @property
    def efficiency_ratio(self) -> float:
        """Out-of-sample / in-sample performance ratio.

        Close to 1.0 = robust. Much less = overfitting.
        Returns 0.0 when the in-sample profit percentage is exactly zero.

        NOTE(review): the two percentages are computed over slices of
        different length (``in_sample_pct`` vs. the remainder) and are not
        normalized by duration -- confirm that comparing them directly is
        intended.
        """
        # Evaluate once: the property re-aggregates every window on access.
        in_sample_pct = self.in_sample_profit_pct
        if in_sample_pct == 0:
            return 0.0
        return self.out_of_sample_profit_pct / in_sample_pct


class WalkForwardEngine:
    """Runs walk-forward analysis on a strategy.

    Splits candle data into N consecutive, non-overlapping windows. For each
    window:
      1. In-sample: Try each param set, pick the best by profit
      2. Out-of-sample: Run best params on held-out data
    """

    # A window shorter than this is not considered meaningful; ``run``
    # returns an empty result instead of analysing it.
    _MIN_WINDOW_CANDLES = 10
    # Minimum candles required on each side of the in/out-of-sample split.
    _MIN_SLICE_CANDLES = 5

    def __init__(
        self,
        strategy_factory: Callable[[], StrategyProtocol],
        param_grid: list[dict],
        initial_balance: Decimal = Decimal("10000"),
        num_windows: int = 5,
        in_sample_pct: float = 0.7,
    ) -> None:
        """Store the analysis configuration.

        Args:
            strategy_factory: Zero-argument callable returning a fresh
                strategy; called once per backtest so no state is shared.
            param_grid: Candidate parameter dicts passed to
                ``strategy.configure``.
            initial_balance: Starting balance handed to every backtest.
            num_windows: Number of windows to split the candles into.
            in_sample_pct: Fraction of each window used for optimization.
        """
        self._strategy_factory = strategy_factory
        self._param_grid = param_grid
        self._initial_balance = initial_balance
        self._num_windows = num_windows
        self._in_sample_pct = in_sample_pct

    def run(self, candles: list[Candle]) -> WalkForwardResult:
        """Run walk-forward analysis over the candle data.

        Returns a result with ``num_windows == 0`` when there are no candles,
        the parameter grid is empty, ``num_windows`` is not positive, or the
        windows would hold fewer than 10 candles each.

        The window size is ``len(candles) // num_windows``; any remainder at
        the end of the series is not analysed.
        """
        strategy_name = self._strategy_factory().name
        symbol = candles[0].symbol if candles else ""

        def empty_result() -> WalkForwardResult:
            return WalkForwardResult(
                strategy_name=strategy_name,
                symbol=symbol,
                num_windows=0,
            )

        # ``num_windows <= 0`` previously crashed with ZeroDivisionError (or
        # produced a negative window size); treat it as "nothing to analyse".
        if not candles or not self._param_grid or self._num_windows <= 0:
            return empty_result()

        window_size = len(candles) // self._num_windows
        if window_size < self._MIN_WINDOW_CANDLES:
            # Not enough data for meaningful walk-forward
            return empty_result()

        windows = []
        for i in range(self._num_windows):
            start = i * window_size
            window_candles = candles[start:start + window_size]

            split = int(len(window_candles) * self._in_sample_pct)
            in_sample = window_candles[:split]
            out_of_sample = window_candles[split:]

            if (
                len(in_sample) < self._MIN_SLICE_CANDLES
                or len(out_of_sample) < self._MIN_SLICE_CANDLES
            ):
                continue

            # Optimize on in-sample, then validate on out-of-sample.
            best_params, best_is_result = self._optimize(in_sample)
            oos_result = self._backtest(best_params, out_of_sample)

            windows.append(WalkForwardWindow(
                window_index=i,
                in_sample_result=best_is_result,
                out_of_sample_result=oos_result,
                best_params=best_params,
            ))

        return WalkForwardResult(
            strategy_name=strategy_name,
            symbol=symbol,
            num_windows=len(windows),
            windows=windows,
        )

    def _optimize(self, in_sample: list[Candle]) -> tuple[dict, BacktestResult]:
        """Return the grid entry with the highest in-sample profit and its result.

        Ties keep the earliest grid entry. The caller guarantees a non-empty
        grid, so a result is always returned.
        """
        best_params: dict = {}
        best_result = None
        for params in self._param_grid:
            result = self._backtest(params, in_sample)
            # Compare against the best seen so far rather than a fixed
            # sentinel, so a grid where every entry loses heavily still
            # yields a winner instead of ``None``.
            if best_result is None or result.profit > best_result.profit:
                best_params, best_result = params, result
        return best_params, best_result

    def _backtest(self, params: dict, candles: list[Candle]) -> BacktestResult:
        """Backtest a freshly built strategy configured with ``params``."""
        strategy = self._strategy_factory()
        strategy.configure(params)
        return BacktestEngine(strategy, self._initial_balance).run(candles)


# ---------------------------------------------------------------------------
# services/backtester/tests/test_walk_forward.py
# ---------------------------------------------------------------------------
"""Tests for walk-forward analysis."""
import math
import sys
from pathlib import Path
from decimal import Decimal
from datetime import datetime, timedelta, timezone

import pytest

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "strategy-engine"))

from shared.models import Candle
from backtester.walk_forward import WalkForwardEngine, WalkForwardResult
from strategies.rsi_strategy import RsiStrategy


def _generate_candles(n=100, base_price=100.0):
    """Build ``n`` hourly BTCUSDT candles whose price cycles every 20 bars."""
    candles = []
    for i in range(n):
        # Simple oscillating price
        price = base_price + (i % 20) - 10
        candles.append(Candle(
            symbol="BTCUSDT", timeframe="1h",
            open_time=datetime(2025, 1, 1, tzinfo=timezone.utc) + timedelta(hours=i),
            open=Decimal(str(price)),
            high=Decimal(str(price + 5)),
            low=Decimal(str(price - 5)),
            close=Decimal(str(price)),
            volume=Decimal("100"),
        ))
    return candles


def test_walk_forward_basic():
    param_grid = [
        {"period": 5, "oversold": 30, "overbought": 70, "quantity": "0.1"},
        {"period": 10, "oversold": 25, "overbought": 75, "quantity": "0.1"},
    ]
    engine = WalkForwardEngine(
        strategy_factory=RsiStrategy,
        param_grid=param_grid,
        initial_balance=Decimal("10000"),
        num_windows=3,
    )
    candles = _generate_candles(150)
    result = engine.run(candles)

    assert isinstance(result, WalkForwardResult)
    assert result.strategy_name == "rsi"
    assert result.num_windows > 0


def test_walk_forward_efficiency_ratio():
    param_grid = [
        {"period": 5, "oversold": 30, "overbought": 70, "quantity": "0.1"},
    ]
    engine = WalkForwardEngine(
        strategy_factory=RsiStrategy,
        param_grid=param_grid,
        num_windows=2,
    )
    candles = _generate_candles(100)
    result = engine.run(candles)

    # Efficiency ratio should be a finite number
    assert isinstance(result.efficiency_ratio, float)
    assert math.isfinite(result.efficiency_ratio)


def test_walk_forward_empty_candles():
    engine = WalkForwardEngine(
        strategy_factory=RsiStrategy,
        param_grid=[{"period": 5}],
    )
    result = engine.run([])
    assert result.num_windows == 0


def test_walk_forward_too_few_candles():
    engine = WalkForwardEngine(
        strategy_factory=RsiStrategy,
        param_grid=[{"period": 5}],
        num_windows=10,
    )
    candles = _generate_candles(20)
    result = engine.run(candles)
    assert result.num_windows == 0  # window_size < 10


def test_walk_forward_zero_windows():
    # A non-positive window count must yield an empty result, not a crash.
    engine = WalkForwardEngine(
        strategy_factory=RsiStrategy,
        param_grid=[{"period": 5}],
        num_windows=0,
    )
    result = engine.run(_generate_candles(50))
    assert result.num_windows == 0
    assert result.windows == []


def test_walk_forward_selects_best_params():
    param_grid = [
        {"period": 3, "oversold": 40, "overbought": 60, "quantity": "0.1"},  # very aggressive
        {"period": 14, "oversold": 30, "overbought": 70, "quantity": "0.1"},  # standard
    ]
    engine = WalkForwardEngine(
        strategy_factory=RsiStrategy,
        param_grid=param_grid,
        num_windows=2,
    )
    candles = _generate_candles(200)
    result = engine.run(candles)

    # Guard against a vacuous pass when no window was produced.
    assert result.windows
    for window in result.windows:
        assert window.best_params in param_grid
