summaryrefslogtreecommitdiff
path: root/services/strategy-engine/strategies/grid_strategy.py
blob: 70443ec032475b43261dc99dd179d769a006dad9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import bisect
from decimal import Decimal
from typing import Optional

import numpy as np

from shared.models import Candle, Signal, OrderSide
from strategies.base import BaseStrategy


class GridStrategy(BaseStrategy):
    """Grid-trading strategy over a fixed price band.

    ``configure`` splits ``[lower_price, upper_price]`` into ``grid_count``
    equal-width bands, i.e. ``grid_count + 1`` evenly spaced price levels.
    On each candle the close price is mapped to a zone index:

    - moving to a lower zone emits a BUY;
    - moving to a higher zone emits a SELL;
    - staying in the same zone emits nothing.

    At most one signal is produced per candle, always with the fixed
    configured quantity, no matter how many levels the price crossed.
    Candidate signals are passed through ``self._apply_filters`` before
    being returned.
    """

    name: str = "grid"

    def __init__(self) -> None:
        super().__init__()
        # Placeholder values; configure() must be called to define the band.
        self._lower_price: float = 0.0
        self._upper_price: float = 0.0
        self._grid_count: int = 5
        self._quantity: Decimal = Decimal("0.01")
        # Ascending price levels. While empty (before configure()), every
        # price maps to zone 0, so no signal can fire.
        self._grid_levels: list[float] = []
        # Zone of the previous close; None means the next candle only seeds it.
        self._last_zone: Optional[int] = None

    @property
    def warmup_period(self) -> int:
        """Number of candles this strategy reports as its warm-up period.

        NOTE(review): presumably 2 because the first candle only seeds the
        zone and the second is the earliest that can signal. Confirm how the
        caller consumes this value.
        """
        return 2

    def configure(self, params: dict) -> None:
        """Validate ``params`` and (re)build the grid.

        Required keys: ``lower_price`` and ``upper_price``.

        Optional keys:

        - ``grid_count`` (default 5);
        - ``quantity`` (default ``"0.01"``, parsed as ``Decimal``);
        - filter settings forwarded to ``self._init_filters``:
          ``adx_threshold`` (20.0), ``min_volume_ratio`` (0.5),
          ``atr_stop_multiplier`` (2.0) and ``atr_tp_multiplier`` (3.0).

        Every value is converted and validated before any attribute is
        assigned. A conversion or validation failure therefore leaves the
        previous configuration fully intact. A successful call also clears
        the remembered zone.

        Raises:
            KeyError: ``lower_price`` or ``upper_price`` is missing.
            ValueError: ``lower_price >= upper_price``, ``grid_count < 2``,
                or a non-positive quantity. Errors raised by the
                ``float``/``int``/``Decimal`` conversions propagate unchanged.
        """
        lower_price = float(params["lower_price"])
        upper_price = float(params["upper_price"])
        grid_count = int(params.get("grid_count", 5))
        quantity = Decimal(str(params.get("quantity", "0.01")))

        if lower_price >= upper_price:
            raise ValueError(
                f"Grid lower_price must be < upper_price, "
                f"got lower={lower_price}, upper={upper_price}"
            )
        if grid_count < 2:
            raise ValueError(f"Grid grid_count must be >= 2, got {grid_count}")
        if quantity <= 0:
            raise ValueError(f"Quantity must be positive, got {quantity}")

        filter_settings = {
            "adx_threshold": float(params.get("adx_threshold", 20.0)),
            "min_volume_ratio": float(params.get("min_volume_ratio", 0.5)),
            "atr_stop_multiplier": float(params.get("atr_stop_multiplier", 2.0)),
            "atr_tp_multiplier": float(params.get("atr_tp_multiplier", 3.0)),
        }

        # Commit only once everything above has succeeded.
        self._lower_price = lower_price
        self._upper_price = upper_price
        self._grid_count = grid_count
        self._quantity = quantity
        # tolist() yields plain Python floats, matching the list[float]
        # annotation (list(ndarray) would store np.float64 scalars).
        self._grid_levels = np.linspace(
            lower_price, upper_price, grid_count + 1
        ).tolist()
        self._last_zone = None

        self._init_filters(require_trend=False, **filter_settings)

    def reset(self) -> None:
        """Forget the remembered zone so the next candle only re-seeds it.

        NOTE(review): this neither calls ``super().reset()`` nor clears the
        data fed by ``_update_filter_data``. Confirm that BaseStrategy
        expects neither.
        """
        self._last_zone = None

    def _get_zone(self, price: float) -> int:
        """Return the grid zone index for ``price``.

        Once configured there are ``grid_count + 1`` levels and therefore
        ``grid_count + 2`` zones:

        - zone 0 is strictly below the lowest level (``lower_price``);
        - zone ``i`` for ``1 <= i <= grid_count`` covers
          ``[levels[i-1], levels[i])``;
        - zone ``grid_count + 1`` is at or above the highest level
          (``upper_price``).

        A price exactly on a level belongs to the zone above it. With no
        levels (not yet configured) every price maps to zone 0.
        """
        # bisect_right gives the index of the first level strictly greater
        # than price, which is exactly the numbering above. The levels are
        # ascending because they come from np.linspace(lower, upper) with
        # lower < upper enforced in configure().
        return bisect.bisect_right(self._grid_levels, price)

    def on_candle(self, candle: Candle) -> Signal | None:
        """Process one candle and return a grid signal, if any.

        Filter data is updated for every candle. The first candle after
        construction, ``configure()`` or ``reset()`` only records its zone and
        returns None. After that:

        - a close in a lower zone than the previous close yields a BUY;
        - a close in a higher zone yields a SELL;
        - a close in the same zone returns None.

        Each candidate signal is returned through ``self._apply_filters``.
        NOTE(review): that call presumably may veto the signal by returning
        None; confirm in BaseStrategy.
        """
        self._update_filter_data(candle)
        current_zone = self._get_zone(float(candle.close))

        prev_zone = self._last_zone
        self._last_zone = current_zone
        if prev_zone is None or current_zone == prev_zone:
            # Seeding candle, or no level crossed: nothing to do.
            return None

        if current_zone < prev_zone:
            # Price moved to a lower zone -> BUY
            side, direction = OrderSide.BUY, "down"
        else:
            # Price moved to a higher zone -> SELL
            side, direction = OrderSide.SELL, "up"

        signal = Signal(
            strategy=self.name,
            symbol=candle.symbol,
            side=side,
            price=candle.close,
            quantity=self._quantity,
            conviction=0.5,
            reason=f"Grid: price crossed {direction} from zone {prev_zone} to {current_zone}",
        )
        return self._apply_filters(signal)