from abc import ABC, abstractmethod
from collections import deque
from decimal import Decimal
from typing import Optional

import pandas as pd

from shared.models import Candle, Signal
from strategies.indicators.trend import adx
from strategies.indicators.volatility import atr
from strategies.indicators.volume import volume_ratio


class BaseStrategy(ABC):
    """Abstract base class for candle-driven strategies with optional signal filters.

    Subclasses implement ``warmup_period``, ``on_candle`` and ``configure``.
    The filter helpers (ADX regime check, volume check, ATR-based stops) are
    inactive until ``_init_filters()`` is called; while inactive every check
    passes and no stops are computed.
    """

    name: str = "base"

    def __init__(self) -> None:
        # Filter state — subclasses can enable by calling _init_filters() in their __init__
        self._filter_enabled: bool = False

        # Rolling buffers of the most recent candle fields (oldest entries are
        # dropped automatically once 500 values are stored).
        self._highs: deque[float] = deque(maxlen=500)
        self._lows: deque[float] = deque(maxlen=500)
        self._closes_filter: deque[float] = deque(maxlen=500)
        self._volumes: deque[float] = deque(maxlen=500)

        # Filter config
        self._adx_period: int = 14
        self._adx_threshold: float = 25.0
        self._require_trend: bool = True  # True = trend-following, False = mean-reversion
        self._volume_period: int = 20
        self._min_volume_ratio: float = 0.5
        self._atr_period: int = 14
        self._atr_stop_multiplier: float = 2.0
        self._atr_tp_multiplier: float = 3.0

    def _init_filters(
        self,
        adx_period: int = 14,
        adx_threshold: float = 25.0,
        require_trend: bool = True,
        volume_period: int = 20,
        min_volume_ratio: float = 0.5,
        atr_period: int = 14,
        atr_stop_multiplier: float = 2.0,
        atr_tp_multiplier: float = 3.0,
    ) -> None:
        """Enable filters and store their configuration.

        Call from subclass __init__ or configure().

        Args:
            adx_period: Period passed to the ADX indicator.
            adx_threshold: ADX level separating the two regimes.
            require_trend: True to allow signals when ADX >= threshold,
                False to allow them when ADX < threshold.
            volume_period: Period passed to the volume-ratio indicator.
            min_volume_ratio: Minimum volume ratio required to allow a signal.
            atr_period: Period passed to the ATR indicator.
            atr_stop_multiplier: ATR multiple used for the stop-loss distance.
            atr_tp_multiplier: ATR multiple used for the take-profit distance.
        """
        self._filter_enabled = True
        self._adx_period = adx_period
        self._adx_threshold = adx_threshold
        self._require_trend = require_trend
        self._volume_period = volume_period
        self._min_volume_ratio = min_volume_ratio
        self._atr_period = atr_period
        self._atr_stop_multiplier = atr_stop_multiplier
        self._atr_tp_multiplier = atr_tp_multiplier

    def _update_filter_data(self, candle: Candle) -> None:
        """Call at the start of on_candle() to track filter data.

        Appends the candle's high, low, close and volume (converted to float)
        to the rolling buffers. Does nothing while filters are disabled.
        """
        if not self._filter_enabled:
            return
        self._highs.append(float(candle.high))
        self._lows.append(float(candle.low))
        self._closes_filter.append(float(candle.close))
        self._volumes.append(float(candle.volume))

    def _price_series(self) -> tuple[pd.Series, pd.Series, pd.Series]:
        """Return the buffered highs, lows and closes as pandas Series."""
        return (
            pd.Series(list(self._highs)),
            pd.Series(list(self._lows)),
            pd.Series(list(self._closes_filter)),
        )

    def _check_regime(self) -> bool:
        """Check if current market regime matches strategy type.

        Returns True if signal should be allowed.

        - require_trend=True (trend strategies): only trade when ADX >= threshold
        - require_trend=False (mean-reversion): only trade when ADX < threshold

        The check passes by default when filters are disabled, when fewer than
        ``2 * adx_period + 1`` closes are buffered, or when the latest ADX
        value is NaN.
        """
        if not self._filter_enabled:
            return True
        if len(self._closes_filter) < self._adx_period * 2 + 1:
            return True  # Not enough data, allow by default

        highs, lows, closes = self._price_series()
        adx_value = adx(highs, lows, closes, self._adx_period).iloc[-1]
        if pd.isna(adx_value):
            return True

        # Coerce to a plain bool: comparing a NumPy scalar yields numpy.bool_,
        # which would violate the declared return type.
        if self._require_trend:
            return bool(adx_value >= self._adx_threshold)
        return bool(adx_value < self._adx_threshold)

    def _check_volume(self) -> bool:
        """Check if current volume is sufficient (above minimum ratio of average).

        The check passes by default when filters are disabled, when fewer than
        ``volume_period + 1`` volumes are buffered, or when the latest ratio
        is NaN.
        """
        if not self._filter_enabled:
            return True
        if len(self._volumes) < self._volume_period + 1:
            return True

        volumes = pd.Series(list(self._volumes))
        ratio = volume_ratio(volumes, self._volume_period).iloc[-1]
        if pd.isna(ratio):
            return True
        # Coerce to a plain bool for the same reason as in _check_regime().
        return bool(ratio >= self._min_volume_ratio)

    def _calculate_atr_stops(
        self, entry_price: Decimal, side: str
    ) -> tuple[Optional[Decimal], Optional[Decimal]]:
        """Calculate ATR-based stop-loss and take-profit.

        Args:
            entry_price: Price the stops are measured from.
            side: "BUY" places the stop below and the target above the entry;
                any other value is treated as a sell (stop above, target below).

        Returns:
            (stop_loss, take_profit) as Decimal, or (None, None) if filters
            are disabled, fewer than ``atr_period + 1`` closes are buffered,
            or the latest ATR value is NaN or zero.
        """
        if not self._filter_enabled:
            return None, None
        if len(self._closes_filter) < self._atr_period + 1:
            return None, None

        highs, lows, closes = self._price_series()
        atr_value = atr(highs, lows, closes, self._atr_period).iloc[-1]
        if pd.isna(atr_value) or atr_value == 0:
            return None, None

        # Go through str() so the Decimal carries the float's printed value
        # rather than its full binary expansion.
        atr_dec = Decimal(str(atr_value))
        stop_offset = atr_dec * Decimal(str(self._atr_stop_multiplier))
        target_offset = atr_dec * Decimal(str(self._atr_tp_multiplier))

        if side == "BUY":
            return entry_price - stop_offset, entry_price + target_offset
        # SELL
        return entry_price + stop_offset, entry_price - target_offset

    def _apply_filters(self, signal: Optional[Signal]) -> Optional[Signal]:
        """Apply all filters to a signal.

        Returns the same signal object, with ``stop_loss`` / ``take_profit``
        set when ATR stops could be computed, or None if the signal is None
        or is rejected by the regime or volume check.
        """
        if signal is None:
            return None
        if not self._check_regime():
            return None
        if not self._check_volume():
            return None

        # Add ATR-based stops
        sl, tp = self._calculate_atr_stops(signal.price, signal.side.value)
        if sl is not None:
            signal.stop_loss = sl
        if tp is not None:
            signal.take_profit = tp
        return signal

    @property
    @abstractmethod
    def warmup_period(self) -> int:
        """Warm-up length for the strategy; defined by each subclass."""

    @abstractmethod
    def on_candle(self, candle: Candle) -> Optional[Signal]:
        """Process one candle and return a Signal, or None when there is none."""

    @abstractmethod
    def configure(self, params: dict) -> None:
        """Apply strategy parameters from ``params``; defined by each subclass."""

    def validate_params(self, params: dict) -> list[str]:
        """Validate parameters and return list of error messages. Empty = valid."""
        return []

    def reset(self) -> None:
        """Discard all buffered filter data.

        Buffers are cleared regardless of whether filters are currently
        enabled, so no stale candles survive a reset.
        """
        self._highs.clear()
        self._lows.clear()
        self._closes_filter.clear()
        self._volumes.clear()