import bisect
import math
from decimal import Decimal
from typing import Optional

import numpy as np

from shared.models import Candle, Signal, OrderSide
from strategies.base import BaseStrategy


class GridStrategy(BaseStrategy):
    """Grid strategy that trades zone crossings inside a fixed price band.

    The band ``[lower_price, upper_price]`` is split into ``grid_count`` equal
    intervals by ``grid_count + 1`` evenly spaced levels. Each candle's close
    is mapped to a zone index. Moving to a lower zone than the previous
    candle's emits a BUY signal, moving to a higher zone emits a SELL signal,
    and staying in the same zone emits nothing.
    """

    name: str = "grid"

    def __init__(self) -> None:
        # Placeholder values until configure() is called. While _grid_levels
        # is empty every price maps to zone 0, so no signal can be emitted.
        self._lower_price: float = 0.0
        self._upper_price: float = 0.0
        self._grid_count: int = 5
        self._quantity: Decimal = Decimal("0.01")
        self._grid_levels: list[float] = []
        # Zone of the most recent close; None means no candle seen since
        # construction, configure() or reset().
        self._last_zone: Optional[int] = None

    @property
    def warmup_period(self) -> int:
        """Return 2.

        on_candle() uses the first candle only to record a zone, so the
        second candle is the earliest that can produce a signal.
        NOTE(review): assumes BaseStrategy treats this as a candle count;
        confirm.
        """
        return 2

    def configure(self, params: dict) -> None:
        """Build the grid from ``params`` and clear the tracked zone.

        Args:
            params: Mapping with required keys ``"lower_price"`` and
                ``"upper_price"``, and optional keys ``"grid_count"``
                (default 5) and ``"quantity"`` (default ``"0.01"``).

        Raises:
            KeyError: If ``"lower_price"`` or ``"upper_price"`` is missing.
            ValueError: If either price is not finite,
                ``lower_price >= upper_price``, ``grid_count < 2`` or
                ``quantity <= 0``.
            decimal.InvalidOperation: If ``quantity`` is not a valid decimal.

        Every value is parsed and validated before any attribute is assigned,
        so a rejected configuration leaves the previous one fully intact.
        """
        lower_price = float(params["lower_price"])
        upper_price = float(params["upper_price"])
        grid_count = int(params.get("grid_count", 5))
        quantity = Decimal(str(params.get("quantity", "0.01")))

        # NaN slips through the ordering check below (all comparisons are
        # False), and an infinite bound yields non-finite linspace levels.
        # In both cases the grid would silently never signal.
        if not (math.isfinite(lower_price) and math.isfinite(upper_price)):
            raise ValueError(
                f"Grid lower_price and upper_price must be finite, "
                f"got lower={lower_price}, upper={upper_price}"
            )
        if lower_price >= upper_price:
            raise ValueError(
                f"Grid lower_price must be < upper_price, "
                f"got lower={lower_price}, upper={upper_price}"
            )
        if grid_count < 2:
            raise ValueError(f"Grid grid_count must be >= 2, got {grid_count}")
        if quantity <= 0:
            raise ValueError(f"Quantity must be positive, got {quantity}")

        self._lower_price = lower_price
        self._upper_price = upper_price
        self._grid_count = grid_count
        self._quantity = quantity
        # grid_count intervals need grid_count + 1 boundary levels. tolist()
        # returns plain Python floats, matching the list[float] annotation.
        self._grid_levels = np.linspace(
            lower_price, upper_price, grid_count + 1
        ).tolist()
        self._last_zone = None

    def reset(self) -> None:
        """Forget the last seen zone; the next candle only re-seeds it."""
        self._last_zone = None

    def _get_zone(self, price: float) -> int:
        """Return the grid zone index for ``price``.

        Let ``L = len(self._grid_levels)``, which is ``grid_count + 1`` once
        configured. The zones are:

        - zone 0: strictly below the lowest level;
        - zone ``i`` for ``1 <= i < L``: the range
          ``[levels[i - 1], levels[i])``;
        - zone ``L``: at or above the highest level.

        A price exactly on a level therefore falls in the zone above it.
        A NaN price compares False against every level and maps to zone ``L``.
        """
        # Levels come from np.linspace with lower < upper, so they are sorted
        # ascending. bisect_right returns the index of the first level > price.
        return bisect.bisect_right(self._grid_levels, price)

    def on_candle(self, candle: Candle) -> Signal | None:
        """Track the zone of ``candle.close`` and emit a signal on a zone change.

        The first candle after construction, configure() or reset() only
        records its zone and returns None. After that:

        - a move to a lower zone returns a BUY signal;
        - a move to a higher zone returns a SELL signal;
        - staying in the same zone returns None.

        The signal always carries the configured quantity, even when the
        price jumped across several zones in one candle.
        """
        current_zone = self._get_zone(float(candle.close))
        prev_zone = self._last_zone
        self._last_zone = current_zone

        if prev_zone is None or current_zone == prev_zone:
            return None

        if current_zone < prev_zone:
            side = OrderSide.BUY
            reason = f"Grid: price crossed down from zone {prev_zone} to {current_zone}"
        else:
            side = OrderSide.SELL
            reason = f"Grid: price crossed up from zone {prev_zone} to {current_zone}"

        return Signal(
            strategy=self.name,
            symbol=candle.symbol,
            side=side,
            price=candle.close,
            quantity=self._quantity,
            reason=reason,
        )