summaryrefslogtreecommitdiff
path: root/services/strategy-engine/strategies/bollinger_strategy.py
blob: 4aceee492c50d1641ede7545f8152b314bf5c329 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from collections import deque
from decimal import Decimal

import pandas as pd

from shared.models import Candle, Signal, OrderSide
from strategies.base import BaseStrategy


class BollingerStrategy(BaseStrategy):
    """Trade re-entries into the Bollinger Bands.

    Bands are ``SMA(period) +/- num_std * STD(period)`` over the most recent
    closing prices. A close below the lower band arms a BUY, which fires on a
    later close at or above the lower band; symmetrically, a close above the
    upper band arms a SELL, which fires on a later close at or below the upper
    band. Candles whose relative band width ``(upper - lower) / sma`` is below
    ``min_bandwidth`` are skipped entirely (no arming, no signal).
    """

    name: str = "bollinger"

    # Minimum number of closes retained. configure() grows the buffer when a
    # longer period needs more history.
    _MIN_HISTORY: int = 500

    def __init__(self) -> None:
        self._closes: deque[float] = deque(maxlen=self._MIN_HISTORY)
        self._period: int = 20
        self._num_std: float = 2.0
        self._min_bandwidth: float = 0.02
        self._quantity: Decimal = Decimal("0.01")
        # Set when a close penetrates a band; cleared when the matching
        # re-entry signal is emitted.
        self._was_below_lower: bool = False
        self._was_above_upper: bool = False

    @property
    def warmup_period(self) -> int:
        """Number of closes required before the bands can be computed."""
        return self._period

    def configure(self, params: dict) -> None:
        """Apply strategy parameters from ``params``.

        Recognised keys (all optional): ``period`` (default 20), ``num_std``
        (default 2.0), ``min_bandwidth`` (default 0.02) and ``quantity``
        (default "0.01", parsed as Decimal).

        Every value is parsed and validated before any is applied, so a
        rejected configuration leaves the previous configuration intact.

        Raises:
            ValueError: if ``period < 2``, ``num_std`` is not > 0 (NaN
                included), or ``quantity <= 0``. Conversion errors raised by
                ``int()``/``float()``/``Decimal()`` propagate unchanged.
        """
        period = int(params.get("period", 20))
        num_std = float(params.get("num_std", 2.0))
        min_bandwidth = float(params.get("min_bandwidth", 0.02))
        quantity = Decimal(str(params.get("quantity", "0.01")))

        if period < 2:
            raise ValueError(f"Bollinger period must be >= 2, got {period}")
        # Written as `not (x > 0)` so NaN is rejected too (NaN <= 0 is False).
        if not (num_std > 0):
            raise ValueError(f"Bollinger num_std must be > 0, got {num_std}")
        if quantity <= 0:
            raise ValueError(f"Quantity must be positive, got {quantity}")

        self._period = period
        self._num_std = num_std
        self._min_bandwidth = min_bandwidth
        self._quantity = quantity

        # A period longer than the history buffer could never satisfy the
        # warmup check, so the strategy would silently never signal. Grow
        # the buffer, keeping any closes already collected.
        maxlen = self._closes.maxlen
        if maxlen is not None and period > maxlen:
            self._closes = deque(self._closes, maxlen=period)

    def reset(self) -> None:
        """Discard price history and any armed band-penetration state."""
        self._closes.clear()
        self._was_below_lower = False
        self._was_above_upper = False

    def on_candle(self, candle: Candle) -> Signal | None:
        """Consume one candle and return a BUY/SELL signal on band re-entry.

        Returns None during warmup, when the relative band width is below
        ``min_bandwidth``, or when no armed re-entry occurred on this candle.
        """
        price = float(candle.close)
        self._closes.append(price)

        if len(self._closes) < self._period:
            return None

        # Only the latest `period` closes define the bands. Reducing that
        # window directly gives the same statistics as the last row of a
        # rolling() over the whole buffer, without rolling over up to the
        # full buffer length on every candle.
        window = pd.Series(list(self._closes)[-self._period:])
        sma = window.mean()
        std = window.std()  # sample std (ddof=1), same default as rolling().std()

        upper = sma + self._num_std * std
        lower = sma - self._num_std * std

        # Bandwidth filter: skip narrow-band (sideways) markets.
        if sma != 0 and (upper - lower) / sma < self._min_bandwidth:
            return None

        # Arm on band penetration; the signal fires on a later re-entry
        # (a close that penetrates cannot also satisfy the re-entry test).
        if price < lower:
            self._was_below_lower = True
        if price > upper:
            self._was_above_upper = True

        # BUY: was below lower band and recovered back inside
        if self._was_below_lower and price >= lower:
            self._was_below_lower = False
            return Signal(
                strategy=self.name,
                symbol=candle.symbol,
                side=OrderSide.BUY,
                price=candle.close,
                quantity=self._quantity,
                reason=f"Price recovered above lower Bollinger Band ({lower:.2f})",
            )

        # SELL: was above upper band and recovered back inside
        if self._was_above_upper and price <= upper:
            self._was_above_upper = False
            return Signal(
                strategy=self.name,
                symbol=candle.symbol,
                side=OrderSide.SELL,
                price=candle.close,
                quantity=self._quantity,
                reason=f"Price recovered below upper Bollinger Band ({upper:.2f})",
            )

        return None