summaryrefslogtreecommitdiff
path: root/services/strategy-engine/strategies/grid_strategy.py
blob: 283bfe5652e99f3a1dbfcc5959162d9305db2dc8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import math
from bisect import bisect_right
from decimal import Decimal
from typing import Optional

import numpy as np

from shared.models import Candle, Signal, OrderSide
from strategies.base import BaseStrategy


class GridStrategy(BaseStrategy):
    """Grid trading strategy over a fixed price band.

    The band ``[lower_price, upper_price]`` is split into ``grid_count``
    intervals by ``grid_count + 1`` evenly spaced levels. Each candle close is
    mapped to a zone; moving to a lower zone emits a BUY and moving to a higher
    zone emits a SELL. If the close moves beyond the outermost levels by more
    than ``exit_threshold_pct`` percent, a single SELL (exit) signal is emitted
    and grid trading is suspended until price comes back inside those bounds.
    """

    name: str = "grid"

    def __init__(self) -> None:
        super().__init__()
        self._lower_price: float = 0.0
        self._upper_price: float = 0.0
        self._grid_count: int = 5
        self._quantity: Decimal = Decimal("0.01")
        # Ascending level prices; empty until configure() is called.
        self._grid_levels: list[float] = []
        # Zone of the previous in-range close; None means "no baseline yet".
        self._last_zone: Optional[int] = None
        self._exit_threshold_pct: float = 5.0
        # True while price is beyond the exit bounds (exit already signalled).
        self._out_of_range: bool = False
        # Intended to track whether any grid positions are held.
        # NOTE(review): never read or updated in this class — confirm whether
        # anything outside relies on it before removing.
        self._in_position: bool = False

    @property
    def warmup_period(self) -> int:
        """Warm-up length reported to the base strategy (presumably in candles)."""
        return 2

    def configure(self, params: dict) -> None:
        """Configure the grid from a parameter dict.

        Required keys: ``lower_price`` and ``upper_price``. Optional keys:
        ``grid_count`` (default 5), ``quantity`` (default "0.01"),
        ``exit_threshold_pct`` (default 5.0), and the filter settings
        ``adx_threshold``, ``min_volume_ratio``, ``atr_stop_multiplier`` and
        ``atr_tp_multiplier``, which are forwarded to ``_init_filters``.

        Every value is parsed and validated before any attribute is modified,
        so a rejected configuration leaves the strategy unchanged.

        Raises:
            KeyError: if ``lower_price`` or ``upper_price`` is missing.
            ValueError: if a parsed value violates a constraint. Parse failures
                from ``float()`` / ``int()`` / ``Decimal()`` propagate as-is.
        """
        lower_price = float(params["lower_price"])
        upper_price = float(params["upper_price"])
        grid_count = int(params.get("grid_count", 5))
        quantity = Decimal(str(params.get("quantity", "0.01")))
        exit_threshold_pct = float(params.get("exit_threshold_pct", 5.0))

        # NaN compares False with everything, so reject non-finite values
        # explicitly before the ordering checks below.
        if not (math.isfinite(lower_price) and math.isfinite(upper_price)):
            raise ValueError(
                f"Grid prices must be finite, got lower={lower_price}, upper={upper_price}"
            )
        if lower_price >= upper_price:
            raise ValueError(
                f"Grid lower_price must be < upper_price, "
                f"got lower={lower_price}, upper={upper_price}"
            )
        if not math.isfinite(exit_threshold_pct) or exit_threshold_pct <= 0:
            raise ValueError(f"exit_threshold_pct must be > 0, got {exit_threshold_pct}")
        if grid_count < 2:
            raise ValueError(f"Grid grid_count must be >= 2, got {grid_count}")
        # is_finite() first: ordering comparisons on Decimal NaN raise
        # InvalidOperation rather than returning False.
        if not quantity.is_finite() or quantity <= 0:
            raise ValueError(f"Quantity must be positive, got {quantity}")

        adx_threshold = float(params.get("adx_threshold", 20.0))
        min_volume_ratio = float(params.get("min_volume_ratio", 0.5))
        atr_stop_multiplier = float(params.get("atr_stop_multiplier", 2.0))
        atr_tp_multiplier = float(params.get("atr_tp_multiplier", 3.0))

        self._lower_price = lower_price
        self._upper_price = upper_price
        self._grid_count = grid_count
        self._quantity = quantity
        self._exit_threshold_pct = exit_threshold_pct
        # grid_count intervals need grid_count + 1 boundary levels.
        self._grid_levels = np.linspace(lower_price, upper_price, grid_count + 1).tolist()
        # A new grid invalidates both the zone baseline and the breakout latch.
        self._last_zone = None
        self._out_of_range = False

        self._init_filters(
            require_trend=False,
            adx_threshold=adx_threshold,
            min_volume_ratio=min_volume_ratio,
            atr_stop_multiplier=atr_stop_multiplier,
            atr_tp_multiplier=atr_tp_multiplier,
        )

    def reset(self) -> None:
        """Reset base-strategy state and forget the zone baseline and breakout latch."""
        super().reset()
        self._last_zone = None
        self._out_of_range = False

    def _get_zone(self, price: float) -> int:
        """Return the grid zone index for ``price``.

        With ``grid_count + 1`` levels, zones run from 0 to ``grid_count + 1``:
        zone 0 is below the lowest level, zone ``grid_count + 1`` is at or above
        the highest level, and zones 1..grid_count lie between adjacent levels.
        A price exactly on a level belongs to the zone above it. Every price
        maps to zone 0 while no grid is configured.
        """
        # Count of levels <= price, i.e. the index of the first level > price.
        return bisect_right(self._grid_levels, price)

    def _make_signal(
        self, candle: Candle, side: OrderSide, conviction: float, reason: str
    ) -> Signal | None:
        """Build a grid Signal at ``candle``'s close and pass it through ``_apply_filters``."""
        return self._apply_filters(
            Signal(
                strategy=self.name,
                symbol=candle.symbol,
                side=side,
                price=candle.close,
                quantity=self._quantity,
                conviction=conviction,
                reason=reason,
            )
        )

    def on_candle(self, candle: Candle) -> Signal | None:
        """Process one candle and return a (filtered) grid signal, or None.

        Filter data is always updated first. Then:

        * If the close is outside the exit bounds (the outermost levels widened
          by ``exit_threshold_pct`` percent), a SELL exit signal is emitted on
          the first such candle and nothing afterwards until price re-enters.
          The zone baseline is cleared so that re-entry re-seeds it.
        * Otherwise the close's zone is compared with the previous in-range
          zone: lower -> BUY, higher -> SELL, same -> None. The first in-range
          candle only seeds the baseline.

        Every emitted signal goes through ``_apply_filters``, which presumably
        may veto it (its behaviour is defined in the base class).
        """
        self._update_filter_data(candle)
        price = float(candle.close)

        if self._grid_levels:
            lowest, highest = self._grid_levels[0], self._grid_levels[-1]
            lower_bound = lowest * (1 - self._exit_threshold_pct / 100)
            upper_bound = highest * (1 + self._exit_threshold_pct / 100)

            if price < lower_bound or price > upper_bound:
                if self._out_of_range:
                    return None  # Exit already signalled; stay silent while outside.
                self._out_of_range = True
                # Drop the pre-breakout baseline: otherwise re-entry would be
                # compared with a stale zone and could report a crossing in the
                # wrong direction.
                self._last_zone = None
                return self._make_signal(
                    candle,
                    OrderSide.SELL,
                    0.8,
                    f"Grid: price {price:.2f} broke out of range [{lowest:.2f}, {highest:.2f}]",
                )
            self._out_of_range = False

        current_zone = self._get_zone(price)
        prev_zone, self._last_zone = self._last_zone, current_zone

        if prev_zone is None or current_zone == prev_zone:
            return None

        if current_zone < prev_zone:
            # Price moved to a lower zone -> BUY
            return self._make_signal(
                candle,
                OrderSide.BUY,
                0.5,
                f"Grid: price crossed down from zone {prev_zone} to {current_zone}",
            )

        # Price moved to a higher zone -> SELL
        return self._make_signal(
            candle,
            OrderSide.SELL,
            0.5,
            f"Grid: price crossed up from zone {prev_zone} to {current_zone}",
        )