summaryrefslogtreecommitdiff
path: root/services/strategy-engine/strategies/base.py
blob: 1d9d289e87fdc6084f41de13812368a5a989642c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from abc import ABC, abstractmethod
from collections import deque
from decimal import Decimal

import pandas as pd

from shared.models import Candle, Signal
from strategies.indicators.trend import adx
from strategies.indicators.volatility import atr
from strategies.indicators.volume import volume_ratio


class BaseStrategy(ABC):
    """Abstract base class for trading strategies with optional signal filters.

    Subclasses must implement ``warmup_period``, ``on_candle`` and
    ``configure``. The market-condition filters (ADX regime, volume ratio,
    ATR-based stops) are disabled by default. A subclass opts in by calling
    ``_init_filters()`` and then:

    * calls ``_update_filter_data(candle)`` at the start of ``on_candle()``;
    * passes each candidate signal through ``_apply_filters(signal)``.
    """

    name: str = "base"

    # Minimum number of candles retained for the filter indicators.
    # _init_filters() grows the buffers when a configured period needs more.
    _DEFAULT_FILTER_HISTORY: int = 500

    def __init__(self) -> None:
        # Filter state — subclasses can enable by calling _init_filters() in their __init__
        self._filter_enabled: bool = False
        self._highs: deque[float] = deque(maxlen=self._DEFAULT_FILTER_HISTORY)
        self._lows: deque[float] = deque(maxlen=self._DEFAULT_FILTER_HISTORY)
        self._closes_filter: deque[float] = deque(maxlen=self._DEFAULT_FILTER_HISTORY)
        self._volumes: deque[float] = deque(maxlen=self._DEFAULT_FILTER_HISTORY)
        # Filter config (same defaults as _init_filters())
        self._adx_period: int = 14
        self._adx_threshold: float = 25.0
        self._require_trend: bool = True  # True = trend-following, False = mean-reversion
        self._volume_period: int = 20
        self._min_volume_ratio: float = 0.5
        self._atr_period: int = 14
        self._atr_stop_multiplier: float = 2.0
        self._atr_tp_multiplier: float = 3.0

    def _init_filters(
        self,
        adx_period: int = 14,
        adx_threshold: float = 25.0,
        require_trend: bool = True,
        volume_period: int = 20,
        min_volume_ratio: float = 0.5,
        atr_period: int = 14,
        atr_stop_multiplier: float = 2.0,
        atr_tp_multiplier: float = 3.0,
    ) -> None:
        """Enable filters. Call from subclass __init__ or configure().

        Args:
            adx_period: Lookback passed to the ADX indicator.
            adx_threshold: ADX level separating trending from ranging markets.
            require_trend: True allows signals only when ADX >= threshold
                (trend-following); False only when ADX < threshold
                (mean-reversion).
            volume_period: Lookback passed to the volume-ratio indicator.
            min_volume_ratio: Minimum volume ratio for a signal to pass.
            atr_period: Lookback passed to the ATR indicator.
            atr_stop_multiplier: ATR multiple used for the stop-loss distance.
            atr_tp_multiplier: ATR multiple used for the take-profit distance.
        """
        self._filter_enabled = True
        self._adx_period = adx_period
        self._adx_threshold = adx_threshold
        self._require_trend = require_trend
        self._volume_period = volume_period
        self._min_volume_ratio = min_volume_ratio
        self._atr_period = atr_period
        self._atr_stop_multiplier = atr_stop_multiplier
        self._atr_tp_multiplier = atr_tp_multiplier

        # Each check below refuses to evaluate until a minimum history length
        # is reached. If the buffer capacity were smaller than that minimum
        # (e.g. adx_period >= 250 with a 500-candle buffer), the check could
        # never engage and would silently allow every signal. Grow the
        # buffers so every configured filter can actually become active.
        required = max(
            self._DEFAULT_FILTER_HISTORY,
            adx_period * 2 + 1,  # _check_regime
            volume_period + 1,  # _check_volume
            atr_period + 1,  # _calculate_atr_stops
        )
        current = self._highs.maxlen
        if current is not None and required > current:
            self._resize_history(required)

    def _resize_history(self, maxlen: int) -> None:
        """Rebuild the four filter buffers with capacity ``maxlen``, keeping their data."""
        self._highs = deque(self._highs, maxlen=maxlen)
        self._lows = deque(self._lows, maxlen=maxlen)
        self._closes_filter = deque(self._closes_filter, maxlen=maxlen)
        self._volumes = deque(self._volumes, maxlen=maxlen)

    def _update_filter_data(self, candle: Candle) -> None:
        """Call at the start of on_candle() to track filter data.

        Appends the candle's high, low, close and volume (as floats) to the
        filter buffers. Does nothing while filters are disabled.
        """
        if self._filter_enabled:
            self._highs.append(float(candle.high))
            self._lows.append(float(candle.low))
            self._closes_filter.append(float(candle.close))
            self._volumes.append(float(candle.volume))

    def _price_series(self) -> tuple[pd.Series, pd.Series, pd.Series]:
        """Return the tracked highs, lows and closes as pandas Series."""
        return (
            pd.Series(list(self._highs)),
            pd.Series(list(self._lows)),
            pd.Series(list(self._closes_filter)),
        )

    def _check_regime(self) -> bool:
        """Check if current market regime matches strategy type.

        Returns True if signal should be allowed.
        - require_trend=True (trend strategies): only trade when ADX >= threshold
        - require_trend=False (mean-reversion): only trade when ADX < threshold

        Also returns True when filters are disabled, when fewer than
        ``2 * adx_period + 1`` candles are tracked, or when the latest ADX
        value is NaN.
        """
        if not self._filter_enabled:
            return True
        if len(self._closes_filter) < self._adx_period * 2 + 1:
            return True  # Not enough data, allow by default

        highs, lows, closes = self._price_series()
        adx_value = adx(highs, lows, closes, self._adx_period).iloc[-1]

        if pd.isna(adx_value):
            return True

        # Comparing a pandas/numpy scalar yields numpy.bool_; coerce so the
        # method honours its ``-> bool`` contract on every path.
        if self._require_trend:
            return bool(adx_value >= self._adx_threshold)
        return bool(adx_value < self._adx_threshold)

    def _check_volume(self) -> bool:
        """Check if current volume is sufficient (above minimum ratio of average).

        Returns True when the latest volume ratio is >= ``min_volume_ratio``.
        Also returns True when filters are disabled, when fewer than
        ``volume_period + 1`` candles are tracked, or when the ratio is NaN.
        """
        if not self._filter_enabled:
            return True
        if len(self._volumes) < self._volume_period + 1:
            return True

        volumes = pd.Series(list(self._volumes))
        ratio = volume_ratio(volumes, self._volume_period).iloc[-1]

        if pd.isna(ratio):
            return True

        return bool(ratio >= self._min_volume_ratio)

    def _calculate_atr_stops(
        self, entry_price: Decimal, side: str
    ) -> tuple[Decimal | None, Decimal | None]:
        """Calculate ATR-based stop-loss and take-profit.

        Args:
            entry_price: Price the stops are measured from.
            side: ``"BUY"`` places the stop below and the target above the
                entry; any other value is treated as a sell (mirrored).

        Returns (stop_loss, take_profit) as Decimal or (None, None) if filters
        are disabled, there is not enough data, or the ATR is NaN or zero.
        """
        if not self._filter_enabled:
            return None, None
        if len(self._closes_filter) < self._atr_period + 1:
            return None, None

        highs, lows, closes = self._price_series()
        atr_value = atr(highs, lows, closes, self._atr_period).iloc[-1]

        if pd.isna(atr_value) or atr_value == 0:
            return None, None

        # Go through str() so the Decimal reflects the printed float value,
        # not its full binary expansion.
        atr_dec = Decimal(str(atr_value))
        stop_distance = atr_dec * Decimal(str(self._atr_stop_multiplier))
        tp_distance = atr_dec * Decimal(str(self._atr_tp_multiplier))

        if side == "BUY":
            return entry_price - stop_distance, entry_price + tp_distance
        # SELL (any non-"BUY" side): stop above entry, target below.
        return entry_price + stop_distance, entry_price - tp_distance

    def _apply_filters(self, signal: Signal | None) -> Signal | None:
        """Apply all filters to a signal. Returns signal with SL/TP or None if filtered out.

        The signal is rejected (None) if the regime or volume check fails.
        Otherwise ATR-based stop_loss/take_profit are set on the signal in
        place when they can be computed, and the same object is returned.
        """
        if signal is None:
            return None

        if not self._check_regime():
            return None

        if not self._check_volume():
            return None

        # Add ATR-based stops
        sl, tp = self._calculate_atr_stops(signal.price, signal.side.value)
        if sl is not None:
            signal.stop_loss = sl
        if tp is not None:
            signal.take_profit = tp

        return signal

    @property
    @abstractmethod
    def warmup_period(self) -> int:
        """Number of candles the strategy needs before it can emit signals."""

    @abstractmethod
    def on_candle(self, candle: Candle) -> Signal | None:
        """Process one candle and return a Signal, or None for no action."""

    @abstractmethod
    def configure(self, params: dict) -> None:
        """Apply strategy parameters from ``params``."""

    def validate_params(self, params: dict) -> list[str]:
        """Validate parameters and return list of error messages. Empty = valid."""
        return []

    def reset(self) -> None:
        """Clear all tracked filter data.

        The buffers are only populated while filters are enabled, so clearing
        them unconditionally is equivalent and simpler.
        """
        self._highs.clear()
        self._lows.clear()
        self._closes_filter.clear()
        self._volumes.clear()