summaryrefslogtreecommitdiff
path: root/services/strategy-engine/strategies/grid_strategy.py
blob: 78e27031085cf0e103495c667cad403785c68a77 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import bisect
from decimal import Decimal
from typing import Optional

import numpy as np

from shared.models import Candle, Signal, OrderSide
from strategies.base import BaseStrategy


class GridStrategy(BaseStrategy):
    """Grid trading strategy over a fixed price band.

    The band ``[lower_price, upper_price]`` is split into ``grid_count``
    equal-width cells by ``grid_count + 1`` evenly spaced price levels. Each
    candle close is mapped to a zone; moving to a lower zone emits a BUY and
    moving to a higher zone emits a SELL, each for the configured quantity.
    """

    name: str = "grid"

    def __init__(self) -> None:
        self._lower_price: float = 0.0
        self._upper_price: float = 0.0
        self._grid_count: int = 5
        self._quantity: Decimal = Decimal("0.01")
        # Ascending price levels. Empty until configure() is called, in which
        # case every price maps to zone 0 and no signal is ever emitted.
        self._grid_levels: list[float] = []
        # Zone of the most recent close; None until the first candle is seen.
        self._last_zone: Optional[int] = None

    @property
    def warmup_period(self) -> int:
        """Return 2.

        The first candle only seeds the starting zone in ``on_candle``, so a
        signal can first appear on the second candle.
        """
        return 2

    def configure(self, params: dict) -> None:
        """Configure the grid band, level count and order size.

        Args:
            params: Must contain ``lower_price`` and ``upper_price``; may
                contain ``grid_count`` (default 5) and ``quantity``
                (default ``"0.01"``).

        Raises:
            KeyError: If ``lower_price`` or ``upper_price`` is missing.
            ValueError: If the prices are not finite with
                ``lower_price < upper_price``, if ``grid_count`` is less than
                1, if ``quantity`` is not a positive finite number, or if a
                price / ``grid_count`` cannot be parsed.
            decimal.InvalidOperation: If ``quantity`` cannot be parsed as a
                Decimal.

        All values are parsed and validated before any attribute is assigned,
        so a rejected configuration leaves the previous one in place. Zone
        tracking is reset because zone indices from an old grid are
        meaningless on the new one.
        """
        lower = float(params["lower_price"])
        upper = float(params["upper_price"])
        grid_count = int(params.get("grid_count", 5))
        quantity = Decimal(str(params.get("quantity", "0.01")))

        # Infinite bounds would make linspace produce NaN levels.
        if not (np.isfinite(lower) and np.isfinite(upper)):
            raise ValueError(
                f"grid prices must be finite, got lower={lower}, upper={upper}"
            )
        # Descending or degenerate levels break the zone ordering.
        if lower >= upper:
            raise ValueError(
                f"lower_price ({lower}) must be less than upper_price ({upper})"
            )
        if grid_count < 1:
            raise ValueError(f"grid_count must be >= 1, got {grid_count}")
        # is_finite() is checked first: ordering comparisons on a Decimal NaN
        # raise InvalidOperation instead of returning False.
        if not quantity.is_finite() or quantity <= 0:
            raise ValueError(f"quantity must be a positive number, got {quantity}")

        self._lower_price = lower
        self._upper_price = upper
        self._grid_count = grid_count
        self._quantity = quantity
        # tolist() yields plain Python floats (not np.float64), matching the
        # list[float] annotation.
        self._grid_levels = np.linspace(lower, upper, grid_count + 1).tolist()
        self._last_zone = None

    def reset(self) -> None:
        """Forget the last seen zone so the next candle re-seeds the grid."""
        self._last_zone = None

    def _get_zone(self, price: float) -> int:
        """Return the grid zone index for ``price``.

        With ``N = grid_count`` there are ``N + 1`` levels ``L0 < ... < LN``:

        * zone 0: ``price < L0`` (below the band);
        * zone k for 1 <= k <= N: ``L(k-1) <= price < Lk``;
        * zone N + 1: ``price >= LN`` (at or above the top level).

        A price exactly on a level belongs to the zone above it. Because the
        levels are sorted ascending, the zone is the number of levels <= price,
        which ``bisect_right`` finds by binary search.
        """
        return bisect.bisect_right(self._grid_levels, price)

    def on_candle(self, candle: Candle) -> Signal | None:
        """Emit a BUY/SELL signal when the close moves into a different zone.

        The first candle after construction, ``configure()`` or ``reset()``
        only records its zone and returns None. After that, a move to a lower
        zone yields BUY and a move to a higher zone yields SELL, priced at the
        candle close. A jump across several levels still yields a single
        signal for the configured quantity. Returns None if the zone is
        unchanged.
        """
        current_zone = self._get_zone(float(candle.close))

        prev_zone = self._last_zone
        self._last_zone = current_zone
        if prev_zone is None or current_zone == prev_zone:
            return None

        if current_zone < prev_zone:
            side, direction = OrderSide.BUY, "down"
        else:
            side, direction = OrderSide.SELL, "up"

        return Signal(
            strategy=self.name,
            symbol=candle.symbol,
            side=side,
            price=candle.close,
            quantity=self._quantity,
            reason=f"Grid: price crossed {direction} from zone {prev_zone} to {current_zone}",
        )