"""Tests that strategy reset() properly clears internal state.

Each test runs a fixed candle series through a freshly built strategy, calls
``reset()``, replays the same series, and asserts that both runs emit the same
signals on the same candles.
"""

import sys
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from pathlib import Path

_REPO_ROOT = Path(__file__).resolve().parents[2]
# Insertion order is preserved from the original: backtester/src ends up first
# on sys.path, ahead of strategy-engine.
sys.path.insert(0, str(_REPO_ROOT / "services" / "strategy-engine"))
sys.path.insert(0, str(_REPO_ROOT / "services" / "backtester" / "src"))

from shared.models import Candle
from strategies.rsi_strategy import RsiStrategy
from strategies.grid_strategy import GridStrategy
from strategies.macd_strategy import MacdStrategy
from strategies.bollinger_strategy import BollingerStrategy
from strategies.ema_crossover_strategy import EmaCrossoverStrategy
from strategies.vwap_strategy import VwapStrategy
from strategies.volume_profile_strategy import VolumeProfileStrategy


def _make_candles(count: int, base_price: float = 100.0) -> list[Candle]:
    """Generate ``count`` hourly AAPL candles with an oscillating price.

    The price cycles through ``base_price - 5`` .. ``base_price + 4`` with a
    period of 10 candles, so oscillation-driven strategies have a chance to
    emit signals. Each candle opens one hour after the previous one, so
    timestamps strictly increase even when ``count`` exceeds 24.
    """
    start = datetime(2025, 1, 1, tzinfo=timezone.utc)
    candles = []
    for i in range(count):
        price = base_price + (i % 10) - 5
        candles.append(
            Candle(
                symbol="AAPL",
                timeframe="1h",
                # Previously ``datetime(2025, 1, 1, i % 24)``, which wrapped
                # back to midnight and repeated timestamps after 24 candles.
                open_time=start + timedelta(hours=i),
                open=Decimal(str(price)),
                high=Decimal(str(price + 1)),
                low=Decimal(str(price - 1)),
                close=Decimal(str(price)),
                volume=Decimal("1000"),
            )
        )
    return candles


def _collect_signals(strategy, candles):
    """Feed candles through a strategy and return the non-None signals in order."""
    signals = []
    for c in candles:
        s = strategy.on_candle(c)
        if s is not None:
            signals.append(s)
    return signals


def _signal_trace(strategy, candles):
    """Feed candles through ``strategy`` and record where each signal fired.

    Returns:
        A list of ``(candle_index, side, price)`` tuples, one per non-None
        value returned by ``strategy.on_candle``.
    """
    trace = []
    for index, candle in enumerate(candles):
        signal = strategy.on_candle(candle)
        if signal is not None:
            trace.append((index, signal.side, signal.price))
    return trace


def _assert_reset_reproduces_signals(strategy, candles):
    """Assert that a run after ``reset()`` matches the first run exactly.

    Including the candle index ensures the signals fire on the same bars,
    not merely in the same order.
    """
    first_run = _signal_trace(strategy, candles)
    strategy.reset()
    second_run = _signal_trace(strategy, candles)
    # NOTE(review): if a strategy emits no signals for this data, both traces
    # are empty and this check passes vacuously.
    assert first_run == second_run


class TestRsiReset:
    def test_reset_produces_same_signals(self):
        _assert_reset_reproduces_signals(RsiStrategy(), _make_candles(50))


class TestGridReset:
    def test_reset_produces_same_signals(self):
        strategy = GridStrategy()
        strategy.configure({"lower_price": 90, "upper_price": 110, "grid_count": 5})
        # configure() is called only once, so this also requires reset() to
        # keep the configured grid parameters.
        _assert_reset_reproduces_signals(strategy, _make_candles(50))


class TestMacdReset:
    def test_reset_produces_same_signals(self):
        # Longer series than the other tests; presumably to cover MACD's
        # warm-up window — confirm against MacdStrategy defaults.
        _assert_reset_reproduces_signals(MacdStrategy(), _make_candles(60))


class TestBollingerReset:
    def test_reset_produces_same_signals(self):
        _assert_reset_reproduces_signals(BollingerStrategy(), _make_candles(50))


class TestEmaCrossoverReset:
    def test_reset_produces_same_signals(self):
        _assert_reset_reproduces_signals(EmaCrossoverStrategy(), _make_candles(50))


class TestVwapReset:
    def test_reset_produces_same_signals(self):
        _assert_reset_reproduces_signals(VwapStrategy(), _make_candles(50))


class TestVolumeProfileReset:
    def test_reset_produces_same_signals(self):
        # Longer series than the other tests; presumably to cover the volume
        # profile's lookback window — confirm against VolumeProfileStrategy.
        _assert_reset_reproduces_signals(VolumeProfileStrategy(), _make_candles(150))