summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> 2026-04-01 18:23:46 +0900
committer TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> 2026-04-01 18:23:46 +0900
commit 9e82c51dfde3941189db1b2d62dcc239442b9dc6 (patch)
tree 2c19068a4b2e381d44d22d662ec095794e144180
parent cd3a06c7788ad8a747b1b4579fb6c45b6c43008e (diff)
feat(backtester): add walk-forward analysis engine
-rw-r--r-- services/backtester/src/backtester/walk_forward.py 145
-rw-r--r-- services/backtester/tests/test_walk_forward.py 103
2 files changed, 248 insertions, 0 deletions
diff --git a/services/backtester/src/backtester/walk_forward.py b/services/backtester/src/backtester/walk_forward.py
new file mode 100644
index 0000000..fe6d020
--- /dev/null
+++ b/services/backtester/src/backtester/walk_forward.py
@@ -0,0 +1,145 @@
+"""Walk-forward analysis for strategy parameter optimization."""
+from dataclasses import dataclass, field
+from decimal import Decimal
+from typing import Callable
+
+from shared.models import Candle
+from backtester.engine import BacktestEngine, BacktestResult, StrategyProtocol
+
+
+@dataclass
+class WalkForwardWindow:
+ """Result for a single in-sample/out-of-sample window."""
+ # Index of the window within the candle series, as assigned by
+ # WalkForwardEngine.run(); windows that are skipped for being too short
+ # leave gaps in the numbering.
+ window_index: int
+ # Backtest of the winning parameter set on the in-sample slice.
+ in_sample_result: BacktestResult
+ # Backtest of the same parameter set on the held-out slice of the window.
+ out_of_sample_result: BacktestResult
+ # The entry of the engine's param_grid that was selected on the in-sample slice.
+ best_params: dict
+
+
+@dataclass
+class WalkForwardResult:
+    """Summary of a walk-forward run over all evaluated windows.
+
+    The percentage properties add up the per-window profits and express the
+    sum relative to the initial balance of the first window. They yield 0.0
+    when there are no windows or that balance is not positive.
+    """
+    strategy_name: str
+    symbol: str
+    num_windows: int
+    windows: list[WalkForwardWindow] = field(default_factory=list)
+
+    def _summed_profit_pct(self, result_attr: str) -> float:
+        """Return the summed profit of each window's ``result_attr`` in percent."""
+        if not self.windows:
+            return 0.0
+        results = [getattr(window, result_attr) for window in self.windows]
+        profit_sum = sum(float(res.profit) for res in results)
+        base = float(results[0].initial_balance)
+        if base > 0:
+            return profit_sum / base * 100
+        return 0.0
+
+    @property
+    def out_of_sample_profit_pct(self) -> float:
+        """Summed out-of-sample profit as a percentage of the initial balance."""
+        return self._summed_profit_pct("out_of_sample_result")
+
+    @property
+    def in_sample_profit_pct(self) -> float:
+        """Summed in-sample profit as a percentage of the initial balance."""
+        return self._summed_profit_pct("in_sample_result")
+
+    @property
+    def efficiency_ratio(self) -> float:
+        """Out-of-sample / in-sample performance ratio.
+
+        Close to 1.0 = robust. Much less = overfitting. Returns 0.0 when the
+        in-sample percentage is zero.
+        """
+        in_sample_pct = self.in_sample_profit_pct
+        if in_sample_pct == 0:
+            return 0.0
+        return self.out_of_sample_profit_pct / in_sample_pct
+
+
+class WalkForwardEngine:
+    """Runs walk-forward analysis on a strategy.
+
+    The candle series is cut into ``num_windows`` consecutive, non-overlapping
+    windows of ``len(candles) // num_windows`` candles each; trailing candles
+    that do not fill a whole window are not used. For each window:
+    1. In-sample: run every param set on the leading ``in_sample_pct`` share
+       of the window and keep the one with the highest profit.
+    2. Out-of-sample: run those params on the remaining candles of the window.
+    """
+
+    # Smallest window, and smallest in-sample / out-of-sample slice, that is
+    # still evaluated.
+    _MIN_WINDOW_SIZE = 10
+    _MIN_SLICE_SIZE = 5
+
+    def __init__(
+        self,
+        strategy_factory: Callable[[], StrategyProtocol],
+        param_grid: list[dict],
+        initial_balance: Decimal = Decimal("10000"),
+        num_windows: int = 5,
+        in_sample_pct: float = 0.7,
+    ) -> None:
+        """Store the analysis configuration.
+
+        Args:
+            strategy_factory: Zero-argument callable returning a fresh strategy;
+                it is called once per backtest so runs do not share state.
+            param_grid: Candidate parameter dicts passed to ``configure``.
+            initial_balance: Starting balance handed to every backtest.
+            num_windows: Number of windows the candle series is split into.
+            in_sample_pct: Fraction of each window used for optimization.
+        """
+        self._strategy_factory = strategy_factory
+        self._param_grid = param_grid
+        self._initial_balance = initial_balance
+        self._num_windows = num_windows
+        self._in_sample_pct = in_sample_pct
+
+    def run(self, candles: list[Candle]) -> WalkForwardResult:
+        """Run walk-forward analysis over the candle data.
+
+        Returns a result with ``num_windows == 0`` when there are no candles,
+        no param sets, a non-positive window count, or fewer than 10 candles
+        per window. Windows whose in-sample or out-of-sample slice has fewer
+        than 5 candles are skipped.
+        """
+        strategy_name = self._strategy_factory().name
+        symbol = candles[0].symbol if candles else ""
+
+        # A zero window count would otherwise divide by zero below; negative
+        # counts already produced an empty result, so treat both alike.
+        if not candles or not self._param_grid or self._num_windows <= 0:
+            return WalkForwardResult(
+                strategy_name=strategy_name, symbol=symbol, num_windows=0,
+            )
+
+        window_size = len(candles) // self._num_windows
+        if window_size < self._MIN_WINDOW_SIZE:
+            # Not enough data for meaningful walk-forward
+            return WalkForwardResult(
+                strategy_name=strategy_name, symbol=symbol, num_windows=0,
+            )
+
+        windows = []
+        for i in range(self._num_windows):
+            start = i * window_size
+            window_candles = candles[start:start + window_size]
+
+            split = int(len(window_candles) * self._in_sample_pct)
+            in_sample = window_candles[:split]
+            out_of_sample = window_candles[split:]
+
+            if (
+                len(in_sample) < self._MIN_SLICE_SIZE
+                or len(out_of_sample) < self._MIN_SLICE_SIZE
+            ):
+                continue
+
+            # Optimize on in-sample, then validate on out-of-sample.
+            best_params, best_is_result = self._optimize(in_sample)
+            oos_result = self._backtest(best_params, out_of_sample)
+
+            windows.append(WalkForwardWindow(
+                window_index=i,
+                in_sample_result=best_is_result,
+                out_of_sample_result=oos_result,
+                best_params=best_params,
+            ))
+
+        return WalkForwardResult(
+            strategy_name=strategy_name,
+            symbol=symbol,
+            num_windows=len(windows),
+            windows=windows,
+        )
+
+    def _backtest(self, params: dict, candles: list[Candle]) -> BacktestResult:
+        """Backtest a freshly built strategy configured with ``params``."""
+        strategy = self._strategy_factory()
+        strategy.configure(params)
+        return BacktestEngine(strategy, self._initial_balance).run(candles)
+
+    def _optimize(self, in_sample: list[Candle]) -> tuple[dict, BacktestResult]:
+        """Return the param set with the highest in-sample profit and its result.
+
+        The first result seeds the comparison instead of a fixed numeric
+        sentinel, so a winner is chosen even when every param set loses
+        heavily. Ties keep the earlier entry of the grid.
+        """
+        best_params = None
+        best_result = None
+        for params in self._param_grid:
+            result = self._backtest(params, in_sample)
+            if best_result is None or result.profit > best_result.profit:
+                best_params = params
+                best_result = result
+        return best_params, best_result
diff --git a/services/backtester/tests/test_walk_forward.py b/services/backtester/tests/test_walk_forward.py
new file mode 100644
index 0000000..e672dac
--- /dev/null
+++ b/services/backtester/tests/test_walk_forward.py
@@ -0,0 +1,103 @@
+"""Tests for walk-forward analysis."""
+import sys
+from pathlib import Path
+from decimal import Decimal
+from datetime import datetime, timedelta, timezone
+
+import pytest
+
+sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
+sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "strategy-engine"))
+
+from shared.models import Candle
+from backtester.walk_forward import WalkForwardEngine, WalkForwardResult
+from strategies.rsi_strategy import RsiStrategy
+
+
+def _generate_candles(n=100, base_price=100.0):
+    """Build ``n`` hourly BTCUSDT candles with a repeating 20-bar price ramp.
+
+    Open and close climb from ``base_price - 10`` to ``base_price + 9`` and
+    then wrap around; high and low sit 5 above and below that price.
+    """
+    first_open = datetime(2025, 1, 1, tzinfo=timezone.utc)
+    levels = (base_price + (idx % 20) - 10 for idx in range(n))
+    return [
+        Candle(
+            symbol="BTCUSDT",
+            timeframe="1h",
+            open_time=first_open + timedelta(hours=idx),
+            open=Decimal(str(level)),
+            high=Decimal(str(level + 5)),
+            low=Decimal(str(level - 5)),
+            close=Decimal(str(level)),
+            volume=Decimal("100"),
+        )
+        for idx, level in enumerate(levels)
+    ]
+
+
+def test_walk_forward_basic():
+    """A three-window run over 150 candles yields an RSI result with windows."""
+    rsi_grid = [
+        {"period": 5, "oversold": 30, "overbought": 70, "quantity": "0.1"},
+        {"period": 10, "oversold": 25, "overbought": 75, "quantity": "0.1"},
+    ]
+    wf_engine = WalkForwardEngine(
+        strategy_factory=RsiStrategy,
+        param_grid=rsi_grid,
+        initial_balance=Decimal("10000"),
+        num_windows=3,
+    )
+
+    outcome = wf_engine.run(_generate_candles(150))
+
+    assert isinstance(outcome, WalkForwardResult)
+    assert outcome.strategy_name == "rsi"
+    assert outcome.num_windows > 0
+
+
+def test_walk_forward_efficiency_ratio():
+    """The efficiency ratio of a completed run is a finite float."""
+    import math
+
+    param_grid = [
+        {"period": 5, "oversold": 30, "overbought": 70, "quantity": "0.1"},
+    ]
+    engine = WalkForwardEngine(
+        strategy_factory=RsiStrategy,
+        param_grid=param_grid,
+        num_windows=2,
+    )
+    candles = _generate_candles(100)
+    result = engine.run(candles)
+
+    ratio = result.efficiency_ratio
+    assert isinstance(ratio, float)
+    # The type check alone would also accept inf and nan.
+    assert math.isfinite(ratio)
+
+
+def test_walk_forward_empty_candles():
+    """Running on an empty candle list reports zero windows."""
+    wf_engine = WalkForwardEngine(
+        strategy_factory=RsiStrategy,
+        param_grid=[{"period": 5}],
+    )
+    no_candles = []
+    outcome = wf_engine.run(no_candles)
+    assert outcome.num_windows == 0
+
+
+def test_walk_forward_too_few_candles():
+    """20 candles over 10 windows give 2 candles each, so nothing is evaluated."""
+    wf_engine = WalkForwardEngine(
+        strategy_factory=RsiStrategy,
+        param_grid=[{"period": 5}],
+        num_windows=10,
+    )
+    outcome = wf_engine.run(_generate_candles(20))
+    # window_size < 10
+    assert outcome.num_windows == 0
+
+
+def test_walk_forward_selects_best_params():
+    """Every evaluated window reports a parameter set taken from the grid."""
+    param_grid = [
+        {"period": 3, "oversold": 40, "overbought": 60, "quantity": "0.1"},  # very aggressive
+        {"period": 14, "oversold": 30, "overbought": 70, "quantity": "0.1"},  # standard
+    ]
+    engine = WalkForwardEngine(
+        strategy_factory=RsiStrategy,
+        param_grid=param_grid,
+        num_windows=2,
+    )
+    candles = _generate_candles(200)
+    result = engine.run(candles)
+
+    # 200 candles in 2 windows give 70/30 slices, so windows must exist;
+    # without this guard the loop below would pass vacuously.
+    assert result.windows
+    for window in result.windows:
+        assert window.best_params in param_grid