summaryrefslogtreecommitdiff
path: root/tests/edge_cases/test_strategy_reset.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/edge_cases/test_strategy_reset.py')
-rw-r--r--tests/edge_cases/test_strategy_reset.py141
1 files changed, 141 insertions, 0 deletions
diff --git a/tests/edge_cases/test_strategy_reset.py b/tests/edge_cases/test_strategy_reset.py
new file mode 100644
index 0000000..f84adf0
--- /dev/null
+++ b/tests/edge_cases/test_strategy_reset.py
@@ -0,0 +1,141 @@
+"""Tests that strategy reset() properly clears internal state."""
+
+import sys
+from datetime import datetime, timezone
+from decimal import Decimal
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "strategy-engine"))
+sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "services" / "backtester" / "src"))
+
+from shared.models import Candle
+from strategies.rsi_strategy import RsiStrategy
+from strategies.grid_strategy import GridStrategy
+from strategies.macd_strategy import MacdStrategy
+from strategies.bollinger_strategy import BollingerStrategy
+from strategies.ema_crossover_strategy import EmaCrossoverStrategy
+from strategies.vwap_strategy import VwapStrategy
+from strategies.volume_profile_strategy import VolumeProfileStrategy
+
+
+def _make_candles(count: int, base_price: float = 100.0) -> list[Candle]:
+ """Generate a list of candles with slight price variation."""
+ candles = []
+ for i in range(count):
+ # Oscillating price to potentially trigger signals
+ price = base_price + (i % 10) - 5
+ candles.append(
+ Candle(
+ symbol="BTCUSDT",
+ timeframe="1h",
+ open_time=datetime(2025, 1, 1, i % 24, tzinfo=timezone.utc),
+ open=Decimal(str(price)),
+ high=Decimal(str(price + 1)),
+ low=Decimal(str(price - 1)),
+ close=Decimal(str(price)),
+ volume=Decimal("1000"),
+ )
+ )
+ return candles
+
+
+def _collect_signals(strategy, candles):
+ """Feed candles through a strategy and collect all signals."""
+ signals = []
+ for c in candles:
+ s = strategy.on_candle(c)
+ if s is not None:
+ signals.append(s)
+ return signals
+
+
+class TestRsiReset:
+ def test_reset_produces_same_signals(self):
+ strategy = RsiStrategy()
+ candles = _make_candles(50)
+ signals1 = _collect_signals(strategy, candles)
+ strategy.reset()
+ signals2 = _collect_signals(strategy, candles)
+ assert len(signals1) == len(signals2)
+ for s1, s2 in zip(signals1, signals2):
+ assert s1.side == s2.side
+ assert s1.price == s2.price
+
+
+class TestGridReset:
+ def test_reset_produces_same_signals(self):
+ strategy = GridStrategy()
+ strategy.configure({"lower_price": 90, "upper_price": 110, "grid_count": 5})
+ candles = _make_candles(50)
+ signals1 = _collect_signals(strategy, candles)
+ strategy.reset()
+ signals2 = _collect_signals(strategy, candles)
+ assert len(signals1) == len(signals2)
+ for s1, s2 in zip(signals1, signals2):
+ assert s1.side == s2.side
+ assert s1.price == s2.price
+
+
+class TestMacdReset:
+ def test_reset_produces_same_signals(self):
+ strategy = MacdStrategy()
+ candles = _make_candles(60)
+ signals1 = _collect_signals(strategy, candles)
+ strategy.reset()
+ signals2 = _collect_signals(strategy, candles)
+ assert len(signals1) == len(signals2)
+ for s1, s2 in zip(signals1, signals2):
+ assert s1.side == s2.side
+ assert s1.price == s2.price
+
+
+class TestBollingerReset:
+ def test_reset_produces_same_signals(self):
+ strategy = BollingerStrategy()
+ candles = _make_candles(50)
+ signals1 = _collect_signals(strategy, candles)
+ strategy.reset()
+ signals2 = _collect_signals(strategy, candles)
+ assert len(signals1) == len(signals2)
+ for s1, s2 in zip(signals1, signals2):
+ assert s1.side == s2.side
+ assert s1.price == s2.price
+
+
+class TestEmaCrossoverReset:
+ def test_reset_produces_same_signals(self):
+ strategy = EmaCrossoverStrategy()
+ candles = _make_candles(50)
+ signals1 = _collect_signals(strategy, candles)
+ strategy.reset()
+ signals2 = _collect_signals(strategy, candles)
+ assert len(signals1) == len(signals2)
+ for s1, s2 in zip(signals1, signals2):
+ assert s1.side == s2.side
+ assert s1.price == s2.price
+
+
+class TestVwapReset:
+ def test_reset_produces_same_signals(self):
+ strategy = VwapStrategy()
+ candles = _make_candles(50)
+ signals1 = _collect_signals(strategy, candles)
+ strategy.reset()
+ signals2 = _collect_signals(strategy, candles)
+ assert len(signals1) == len(signals2)
+ for s1, s2 in zip(signals1, signals2):
+ assert s1.side == s2.side
+ assert s1.price == s2.price
+
+
+class TestVolumeProfileReset:
+ def test_reset_produces_same_signals(self):
+ strategy = VolumeProfileStrategy()
+ candles = _make_candles(150)
+ signals1 = _collect_signals(strategy, candles)
+ strategy.reset()
+ signals2 = _collect_signals(strategy, candles)
+ assert len(signals1) == len(signals2)
+ for s1, s2 in zip(signals1, signals2):
+ assert s1.side == s2.side
+ assert s1.price == s2.price