summaryrefslogtreecommitdiff
path: root/services/strategy-engine/strategies/indicators/volatility.py
blob: c16143e44521e97dfb765066bda759d58fbf95d3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""Volatility indicators: ATR, Bollinger Bands, Keltner Channels."""

import pandas as pd
import numpy as np


def atr(
    highs: pd.Series,
    lows: pd.Series,
    closes: pd.Series,
    period: int = 14,
) -> pd.Series:
    """Average True Range using Wilder's smoothing.

    True range for bar ``i`` is the largest of ``high[i] - low[i]``,
    ``|high[i] - close[i-1]|`` and ``|low[i] - close[i-1]|``; bar 0 has no
    previous close, so its true range is just ``high[0] - low[0]``.

    The ATR at position ``period - 1`` is the simple mean of the first
    ``period`` true ranges; every later value follows Wilder's recursion
    ``atr[i] = (atr[i-1] * (period - 1) + tr[i]) / period``.

    Inputs are combined positionally (bar-for-bar), not aligned by index.

    Args:
        highs: Bar highs.
        lows: Bar lows.
        closes: Bar closes; its index, if it has one, labels the result.
        period: Smoothing length in bars; must be at least 1.

    Returns:
        Float series with one value per bar. Entries before position
        ``period - 1`` are NaN, and every entry is NaN when there are fewer
        than ``period`` bars. An empty input yields an empty series.

    Raises:
        ValueError: If ``period < 1`` or the three inputs differ in length.
    """
    if period < 1:
        raise ValueError(f"period must be >= 1, got {period}")

    high = np.asarray(highs, dtype=float)
    low = np.asarray(lows, dtype=float)
    close = np.asarray(closes, dtype=float)
    n = len(close)
    if len(high) != n or len(low) != n:
        raise ValueError(
            "highs, lows and closes must have equal length, "
            f"got {len(high)}, {len(low)}, {n}"
        )
    index = closes.index if hasattr(closes, "index") else None

    # Vectorised true range. np.maximum propagates NaN, so a missing value
    # yields a NaN true range regardless of which term it appears in
    # (Python's built-in max() was argument-order dependent here).
    tr = high - low
    if n > 1:
        prev_close = close[:-1]
        tr[1:] = np.maximum.reduce([
            tr[1:],
            np.abs(high[1:] - prev_close),
            np.abs(low[1:] - prev_close),
        ])

    atr_vals = np.full(n, np.nan)
    if n >= period:
        # Seed with a simple average, then apply Wilder's recursion, which is
        # inherently sequential and therefore stays a plain loop.
        atr_vals[period - 1] = tr[:period].mean()
        for i in range(period, n):
            atr_vals[i] = (atr_vals[i - 1] * (period - 1) + tr[i]) / period

    return pd.Series(atr_vals, index=index)


def bollinger_bands(
    closes: pd.Series,
    period: int = 20,
    num_std: float = 2.0,
    ddof: int = 1,
) -> tuple[pd.Series, pd.Series, pd.Series]:
    """Bollinger Bands.

    The middle band is the ``period``-bar simple moving average of
    ``closes``; the upper and lower bands sit ``num_std`` rolling standard
    deviations above and below it. The first ``period - 1`` entries of each
    band are NaN because the rolling window is not yet full.

    Args:
        closes: Price series.
        period: Rolling window length in bars.
        num_std: Band half-width, in standard deviations.
        ddof: Delta degrees of freedom for the rolling standard deviation.
            The default ``1`` (sample std, pandas' default) preserves this
            function's historical output. Bollinger's own definition and
            TA-Lib use the population std, i.e. ``ddof=0``.

    Returns: (upper_band, middle_band, lower_band)
    """
    # One rolling window object serves both the mean and the std.
    window = closes.rolling(window=period)
    middle = window.mean()
    width = num_std * window.std(ddof=ddof)
    return middle + width, middle, middle - width


def keltner_channels(
    highs: pd.Series,
    lows: pd.Series,
    closes: pd.Series,
    ema_period: int = 20,
    atr_period: int = 14,
    atr_multiplier: float = 2.0,
) -> tuple[pd.Series, pd.Series, pd.Series]:
    """Keltner Channels: an EMA centre line with ATR-scaled envelopes.

    The centre line is the ``ema_period`` EMA of ``closes`` (computed by the
    trend module's ``ema``); the channels sit ``atr_multiplier`` times the
    ``atr_period`` ATR above and below it.

    Returns: (upper_channel, middle_ema, lower_channel)
    """
    # Imported inside the function rather than at module level; presumably
    # to avoid an import cycle with the trend module — TODO confirm.
    from strategies.indicators.trend import ema as calc_ema

    centre = calc_ema(closes, ema_period)
    offset = atr_multiplier * atr(highs, lows, closes, atr_period)
    return centre + offset, centre, centre - offset