diff options
Diffstat (limited to 'tests/edge_cases/test_strategy_reset.py')
| -rw-r--r-- | tests/edge_cases/test_strategy_reset.py | 141 |
1 files changed, 141 insertions, 0 deletions
diff --git a/tests/edge_cases/test_strategy_reset.py b/tests/edge_cases/test_strategy_reset.py new file mode 100644 index 0000000..f84adf0 --- /dev/null +++ b/tests/edge_cases/test_strategy_reset.py @@ -0,0 +1,141 @@ +"""Tests that strategy reset() properly clears internal state.""" + +import sys +from datetime import datetime, timezone +from decimal import Decimal +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "strategy-engine")) +sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "backtester" / "src")) + +from shared.models import Candle +from strategies.rsi_strategy import RsiStrategy +from strategies.grid_strategy import GridStrategy +from strategies.macd_strategy import MacdStrategy +from strategies.bollinger_strategy import BollingerStrategy +from strategies.ema_crossover_strategy import EmaCrossoverStrategy +from strategies.vwap_strategy import VwapStrategy +from strategies.volume_profile_strategy import VolumeProfileStrategy + + +def _make_candles(count: int, base_price: float = 100.0) -> list[Candle]: + """Generate a list of candles with slight price variation.""" + candles = [] + for i in range(count): + # Oscillating price to potentially trigger signals + price = base_price + (i % 10) - 5 + candles.append( + Candle( + symbol="BTCUSDT", + timeframe="1h", + open_time=datetime(2025, 1, 1, i % 24, tzinfo=timezone.utc), + open=Decimal(str(price)), + high=Decimal(str(price + 1)), + low=Decimal(str(price - 1)), + close=Decimal(str(price)), + volume=Decimal("1000"), + ) + ) + return candles + + +def _collect_signals(strategy, candles): + """Feed candles through a strategy and collect all signals.""" + signals = [] + for c in candles: + s = strategy.on_candle(c) + if s is not None: + signals.append(s) + return signals + + +class TestRsiReset: + def test_reset_produces_same_signals(self): + strategy = RsiStrategy() + candles = _make_candles(50) + signals1 = _collect_signals(strategy, candles) + strategy.reset() + signals2 = _collect_signals(strategy, candles) + assert len(signals1) == len(signals2) + for s1, s2 in zip(signals1, signals2): + assert s1.side == s2.side + assert s1.price == s2.price + + +class TestGridReset: + def test_reset_produces_same_signals(self): + strategy = GridStrategy() + strategy.configure({"lower_price": 90, "upper_price": 110, "grid_count": 5}) + candles = _make_candles(50) + signals1 = _collect_signals(strategy, candles) + strategy.reset() + signals2 = _collect_signals(strategy, candles) + assert len(signals1) == len(signals2) + for s1, s2 in zip(signals1, signals2): + assert s1.side == s2.side + assert s1.price == s2.price + + +class TestMacdReset: + def test_reset_produces_same_signals(self): + strategy = MacdStrategy() + candles = _make_candles(60) + signals1 = _collect_signals(strategy, candles) + strategy.reset() + signals2 = _collect_signals(strategy, candles) + assert len(signals1) == len(signals2) + for s1, s2 in zip(signals1, signals2): + assert s1.side == s2.side + assert s1.price == s2.price + + +class TestBollingerReset: + def test_reset_produces_same_signals(self): + strategy = BollingerStrategy() + candles = _make_candles(50) + signals1 = _collect_signals(strategy, candles) + strategy.reset() + signals2 = _collect_signals(strategy, candles) + assert len(signals1) == len(signals2) + for s1, s2 in zip(signals1, signals2): + assert s1.side == s2.side + assert s1.price == s2.price + + +class TestEmaCrossoverReset: + def test_reset_produces_same_signals(self): + strategy = EmaCrossoverStrategy() + candles = _make_candles(50) + signals1 = _collect_signals(strategy, candles) + strategy.reset() + signals2 = _collect_signals(strategy, candles) + assert len(signals1) == len(signals2) + for s1, s2 in zip(signals1, signals2): + assert s1.side == s2.side + assert s1.price == s2.price + + +class TestVwapReset: + def test_reset_produces_same_signals(self): + strategy = VwapStrategy() + candles = _make_candles(50) + signals1 = _collect_signals(strategy, candles) + strategy.reset() + signals2 = _collect_signals(strategy, candles) + assert len(signals1) == len(signals2) + for s1, s2 in zip(signals1, signals2): + assert s1.side == s2.side + assert s1.price == s2.price + + +class TestVolumeProfileReset: + def test_reset_produces_same_signals(self): + strategy = VolumeProfileStrategy() + candles = _make_candles(150) + signals1 = _collect_signals(strategy, candles) + strategy.reset() + signals2 = _collect_signals(strategy, candles) + assert len(signals1) == len(signals2) + for s1, s2 in zip(signals1, signals2): + assert s1.side == s2.side + assert s1.price == s2.price |
