summaryrefslogtreecommitdiff
path: root/services/backtester/tests/test_walk_forward.py
blob: b1aa12c7f7269effc74f74683130c99fed6822fa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Tests for walk-forward analysis."""

import math
import sys
from datetime import UTC, datetime, timedelta
from decimal import Decimal
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "strategy-engine"))

from backtester.walk_forward import WalkForwardEngine, WalkForwardResult
from strategies.rsi_strategy import RsiStrategy

from shared.models import Candle


def _generate_candles(n=100, base_price=100.0):
    """Build ``n`` hourly AAPL candles with a repeating sawtooth price.

    The price steps up by one per candle from ``base_price - 10`` to
    ``base_price + 9`` and then wraps back, giving strategies regular swings.
    Open and close both equal that price, high/low sit 5 above/below it, and
    volume is a constant 100. Timestamps start at 2025-01-01 00:00 UTC and
    advance one hour per candle.
    """
    start = datetime(2025, 1, 1, tzinfo=UTC)

    def _candle_at(idx):
        # (idx % 20) cycles 0..19; shifting by -10 centres it on base_price.
        level = base_price + (idx % 20) - 10
        return Candle(
            symbol="AAPL",
            timeframe="1h",
            open_time=start + timedelta(hours=idx),
            open=Decimal(str(level)),
            high=Decimal(str(level + 5)),
            low=Decimal(str(level - 5)),
            close=Decimal(str(level)),
            volume=Decimal("100"),
        )

    return [_candle_at(idx) for idx in range(n)]


def test_walk_forward_basic():
    """Run a two-entry RSI grid over 150 candles split into 3 windows.

    Checks the result type, the reported strategy name, and that at least one
    window was evaluated.
    """
    short_rsi = {"period": 5, "oversold": 30, "overbought": 70, "quantity": "0.1"}
    long_rsi = {"period": 10, "oversold": 25, "overbought": 75, "quantity": "0.1"}
    result = WalkForwardEngine(
        strategy_factory=RsiStrategy,
        param_grid=[short_rsi, long_rsi],
        initial_balance=Decimal("10000"),
        num_windows=3,
    ).run(_generate_candles(150))

    assert isinstance(result, WalkForwardResult)
    assert result.strategy_name == "rsi"
    assert result.num_windows > 0


def test_walk_forward_efficiency_ratio():
    """The walk-forward efficiency ratio is reported as a finite float.

    Uses a single RSI parameter set over 100 candles split into 2 windows.
    """
    param_grid = [
        {"period": 5, "oversold": 30, "overbought": 70, "quantity": "0.1"},
    ]
    engine = WalkForwardEngine(
        strategy_factory=RsiStrategy,
        param_grid=param_grid,
        num_windows=2,
    )
    candles = _generate_candles(100)
    result = engine.run(candles)

    # Efficiency ratio should be a finite number. isinstance() alone would
    # also accept inf/nan (e.g. from a zero denominator), so check
    # finiteness explicitly.
    assert isinstance(result.efficiency_ratio, float)
    assert math.isfinite(result.efficiency_ratio)


def test_walk_forward_empty_candles():
    """Running on an empty candle list yields a result with zero windows."""
    wf = WalkForwardEngine(
        strategy_factory=RsiStrategy,
        param_grid=[{"period": 5}],
    )
    assert wf.run([]).num_windows == 0


def test_walk_forward_too_few_candles():
    """Asking for 10 windows over only 20 candles produces no windows."""
    engine = WalkForwardEngine(
        strategy_factory=RsiStrategy,
        param_grid=[{"period": 5}],
        num_windows=10,
    )
    outcome = engine.run(_generate_candles(20))
    # Roughly 2 candles per window; expected to fall below the engine's
    # minimum window size (assumed to be 10 — confirm in WalkForwardEngine).
    assert outcome.num_windows == 0


def test_walk_forward_selects_best_params():
    """Every window's selected parameters come from the supplied grid.

    Runs two RSI configurations over 200 candles split into 2 windows and
    checks each window's ``best_params`` is one of the grid entries.
    """
    param_grid = [
        {"period": 3, "oversold": 40, "overbought": 60, "quantity": "0.1"},  # very aggressive
        {"period": 14, "oversold": 30, "overbought": 70, "quantity": "0.1"},  # standard
    ]
    engine = WalkForwardEngine(
        strategy_factory=RsiStrategy,
        param_grid=param_grid,
        num_windows=2,
    )
    candles = _generate_candles(200)
    result = engine.run(candles)

    # Without this guard the loop below asserts nothing when the engine
    # produces no windows, and the test would pass vacuously.
    assert result.windows, "expected at least one walk-forward window"
    for window in result.windows:
        assert window.best_params in param_grid