summaryrefslogtreecommitdiff
path: root/services
diff options
context:
space:
mode:
Diffstat (limited to 'services')
-rw-r--r--services/order-executor/src/order_executor/config.py5
-rw-r--r--services/order-executor/src/order_executor/main.py4
-rw-r--r--services/order-executor/src/order_executor/risk_manager.py112
-rw-r--r--services/order-executor/tests/test_risk_manager.py132
4 files changed, 245 insertions, 8 deletions
diff --git a/services/order-executor/src/order_executor/config.py b/services/order-executor/src/order_executor/config.py
index 6542a31..14828ea 100644
--- a/services/order-executor/src/order_executor/config.py
+++ b/services/order-executor/src/order_executor/config.py
@@ -4,4 +4,7 @@ from shared.config import Settings
class ExecutorConfig(Settings):
- pass
+ risk_trailing_stop_pct: float = 0.0
+ risk_max_open_positions: int = 10
+ risk_volatility_lookback: int = 20
+ risk_volatility_scale: bool = False
diff --git a/services/order-executor/src/order_executor/main.py b/services/order-executor/src/order_executor/main.py
index 24a166e..32470f6 100644
--- a/services/order-executor/src/order_executor/main.py
+++ b/services/order-executor/src/order_executor/main.py
@@ -42,6 +42,10 @@ async def run() -> None:
max_position_size=Decimal(str(config.risk_max_position_size)),
stop_loss_pct=Decimal(str(config.risk_stop_loss_pct)),
daily_loss_limit_pct=Decimal(str(config.risk_daily_loss_limit_pct)),
+ trailing_stop_pct=Decimal(str(config.risk_trailing_stop_pct)),
+ max_open_positions=config.risk_max_open_positions,
+ volatility_lookback=config.risk_volatility_lookback,
+ volatility_scale=config.risk_volatility_scale,
)
executor = OrderExecutor(
diff --git a/services/order-executor/src/order_executor/risk_manager.py b/services/order-executor/src/order_executor/risk_manager.py
index db162e1..2b0a864 100644
--- a/services/order-executor/src/order_executor/risk_manager.py
+++ b/services/order-executor/src/order_executor/risk_manager.py
@@ -1,7 +1,8 @@
"""Risk management for order execution."""
-
from dataclasses import dataclass
from decimal import Decimal
+from collections import deque
+import math
from shared.models import Signal, OrderSide, Position
@@ -12,18 +13,106 @@ class RiskCheckResult:
reason: str
+@dataclass
+class TrailingStop:
+ """Tracks trailing stop for a symbol."""
+ symbol: str
+ highest_price: Decimal
+ stop_pct: Decimal # e.g. 5.0 for 5%
+
+ @property
+ def stop_price(self) -> Decimal:
+ return self.highest_price * (1 - self.stop_pct / 100)
+
+ def update(self, current_price: Decimal) -> None:
+ if current_price > self.highest_price:
+ self.highest_price = current_price
+
+ def is_triggered(self, current_price: Decimal) -> bool:
+ return current_price <= self.stop_price
+
+
class RiskManager:
- """Evaluates risk before order execution."""
+ """Evaluates risk before order execution with advanced features."""
def __init__(
self,
max_position_size: Decimal,
stop_loss_pct: Decimal,
daily_loss_limit_pct: Decimal,
+ trailing_stop_pct: Decimal = Decimal("0"),
+ max_open_positions: int = 10,
+ volatility_lookback: int = 20,
+ volatility_scale: bool = False,
) -> None:
self.max_position_size = max_position_size
self.stop_loss_pct = stop_loss_pct
self.daily_loss_limit_pct = daily_loss_limit_pct
+ self.trailing_stop_pct = trailing_stop_pct
+ self.max_open_positions = max_open_positions
+ self.volatility_lookback = volatility_lookback
+ self.volatility_scale = volatility_scale
+
+ self._trailing_stops: dict[str, TrailingStop] = {}
+ self._price_history: dict[str, deque[float]] = {}
+
+ def update_price(self, symbol: str, price: Decimal) -> None:
+ """Update price tracking for trailing stops and volatility."""
+ # Trailing stop
+ if symbol in self._trailing_stops:
+ self._trailing_stops[symbol].update(price)
+
+ # Price history for volatility
+ if symbol not in self._price_history:
+ self._price_history[symbol] = deque(maxlen=self.volatility_lookback)
+ self._price_history[symbol].append(float(price))
+
+ def set_trailing_stop(self, symbol: str, entry_price: Decimal) -> None:
+ """Set a trailing stop for a new position."""
+ if self.trailing_stop_pct > 0:
+ self._trailing_stops[symbol] = TrailingStop(
+ symbol=symbol,
+ highest_price=entry_price,
+ stop_pct=self.trailing_stop_pct,
+ )
+
+ def remove_trailing_stop(self, symbol: str) -> None:
+ """Remove trailing stop when position is closed."""
+ self._trailing_stops.pop(symbol, None)
+
+ def get_volatility(self, symbol: str) -> float | None:
+ """Calculate annualized volatility for a symbol."""
+ history = self._price_history.get(symbol)
+ if not history or len(history) < 2:
+ return None
+ prices = list(history)
+ returns = [(prices[i] - prices[i-1]) / prices[i-1] for i in range(1, len(prices)) if prices[i-1] != 0]
+ if not returns:
+ return None
+ mean = sum(returns) / len(returns)
+ var = sum((r - mean) ** 2 for r in returns) / len(returns)
+ daily_vol = math.sqrt(var)
+ return daily_vol * math.sqrt(365) # Annualized
+
+ def calculate_position_size(self, symbol: str, balance: Decimal) -> Decimal:
+ """Calculate position size adjusted for volatility.
+
+ Lower volatility -> larger position, higher volatility -> smaller position.
+ Base: max_position_size of balance. Scaled by inverse volatility.
+ """
+ base_size = balance * self.max_position_size
+
+ if not self.volatility_scale:
+ return base_size
+
+ vol = self.get_volatility(symbol)
+ if vol is None or vol == 0:
+ return base_size
+
+ # Target volatility of 20% annualized
+ target_vol = 0.20
+ scale = min(target_vol / vol, 2.0) # Cap at 2x
+ return base_size * Decimal(str(scale))
def check(
self,
@@ -37,6 +126,15 @@ class RiskManager:
if balance > 0 and (daily_pnl / balance) * 100 < -self.daily_loss_limit_pct:
return RiskCheckResult(allowed=False, reason="Daily loss limit exceeded")
+ # Check trailing stop
+ if signal.side == OrderSide.BUY:
+ trailing = self._trailing_stops.get(signal.symbol)
+ if trailing and trailing.is_triggered(signal.price):
+ return RiskCheckResult(
+ allowed=False,
+ reason=f"Trailing stop triggered at {trailing.stop_price}",
+ )
+
if signal.side == OrderSide.BUY:
order_cost = signal.price * signal.quantity
@@ -44,16 +142,18 @@ class RiskManager:
if order_cost > balance:
return RiskCheckResult(allowed=False, reason="Insufficient balance")
+ # Check max open positions
+ open_count = sum(1 for p in positions.values() if p.quantity > 0)
+ if open_count >= self.max_open_positions:
+ return RiskCheckResult(allowed=False, reason="Max open positions reached")
+
# Check position size limit
position = positions.get(signal.symbol)
current_position_value = Decimal(0)
if position is not None:
current_position_value = position.quantity * position.current_price
- if (
- balance > 0
- and (current_position_value + order_cost) / balance > self.max_position_size
- ):
+ if balance > 0 and (current_position_value + order_cost) / balance > self.max_position_size:
return RiskCheckResult(allowed=False, reason="Position size exceeded")
return RiskCheckResult(allowed=True, reason="OK")
diff --git a/services/order-executor/tests/test_risk_manager.py b/services/order-executor/tests/test_risk_manager.py
index a122d16..efabe73 100644
--- a/services/order-executor/tests/test_risk_manager.py
+++ b/services/order-executor/tests/test_risk_manager.py
@@ -3,7 +3,7 @@
from decimal import Decimal
-from shared.models import OrderSide, Signal
+from shared.models import OrderSide, Position, Signal
from order_executor.risk_manager import RiskManager
@@ -22,11 +22,28 @@ def make_risk_manager(
max_position_size: str = "0.1",
stop_loss_pct: str = "5.0",
daily_loss_limit_pct: str = "10.0",
+ trailing_stop_pct: str = "0",
+ max_open_positions: int = 10,
+ volatility_lookback: int = 20,
+ volatility_scale: bool = False,
) -> RiskManager:
return RiskManager(
max_position_size=Decimal(max_position_size),
stop_loss_pct=Decimal(stop_loss_pct),
daily_loss_limit_pct=Decimal(daily_loss_limit_pct),
+ trailing_stop_pct=Decimal(trailing_stop_pct),
+ max_open_positions=max_open_positions,
+ volatility_lookback=volatility_lookback,
+ volatility_scale=volatility_scale,
+ )
+
+
+def make_position(symbol: str, quantity: str, avg_entry: str, current: str) -> Position:
+ return Position(
+ symbol=symbol,
+ quantity=Decimal(quantity),
+ avg_entry_price=Decimal(avg_entry),
+ current_price=Decimal(current),
)
@@ -68,3 +85,116 @@ def test_risk_check_rejects_insufficient_balance():
result = rm.check(signal, balance=Decimal("100"), positions={}, daily_pnl=Decimal("0"))
assert result.allowed is False
assert result.reason == "Insufficient balance"
+
+
+# --- Trailing stop tests ---
+
+
+def test_trailing_stop_set_and_trigger():
+ """Trailing stop should trigger when price drops below stop level."""
+ rm = make_risk_manager(trailing_stop_pct="5")
+ rm.set_trailing_stop("BTC/USDT", Decimal("100"))
+
+ signal = make_signal(side=OrderSide.BUY, price="94", quantity="0.01")
+ result = rm.check(signal, balance=Decimal("10000"), positions={}, daily_pnl=Decimal("0"))
+ assert result.allowed is False
+ assert "Trailing stop triggered" in result.reason
+
+
+def test_trailing_stop_updates_highest_price():
+ """Trailing stop should track the highest price seen."""
+ rm = make_risk_manager(trailing_stop_pct="5")
+ rm.set_trailing_stop("BTC/USDT", Decimal("100"))
+
+ # Price rises to 120 => stop at 114
+ rm.update_price("BTC/USDT", Decimal("120"))
+
+ # Price at 115 is above stop (114), should be allowed
+ signal = make_signal(side=OrderSide.BUY, price="115", quantity="0.01")
+ result = rm.check(signal, balance=Decimal("10000"), positions={}, daily_pnl=Decimal("0"))
+ assert result.allowed is True
+
+ # Price at 113 is below stop (114), should be rejected
+ signal = make_signal(side=OrderSide.BUY, price="113", quantity="0.01")
+ result = rm.check(signal, balance=Decimal("10000"), positions={}, daily_pnl=Decimal("0"))
+ assert result.allowed is False
+ assert "Trailing stop triggered" in result.reason
+
+
+def test_trailing_stop_not_triggered_above_stop():
+ """Trailing stop should not trigger when price is above stop level."""
+ rm = make_risk_manager(trailing_stop_pct="5")
+ rm.set_trailing_stop("BTC/USDT", Decimal("100"))
+
+ # Price at 96 is above stop (95), should be allowed
+ signal = make_signal(side=OrderSide.BUY, price="96", quantity="0.01")
+ result = rm.check(signal, balance=Decimal("10000"), positions={}, daily_pnl=Decimal("0"))
+ assert result.allowed is True
+
+
+# --- Max open positions test ---
+
+
+def test_max_open_positions_check():
+ """Should reject new BUY when max open positions is reached."""
+ rm = make_risk_manager(max_open_positions=2)
+
+ positions = {
+ "BTC/USDT": make_position("BTC/USDT", "1", "100", "100"),
+ "ETH/USDT": make_position("ETH/USDT", "10", "50", "50"),
+ }
+
+ signal = make_signal(side=OrderSide.BUY, price="10", quantity="1", symbol="SOL/USDT")
+ result = rm.check(signal, balance=Decimal("10000"), positions=positions, daily_pnl=Decimal("0"))
+ assert result.allowed is False
+ assert result.reason == "Max open positions reached"
+
+
+# --- Volatility tests ---
+
+
+def test_volatility_calculation():
+ """Volatility should be calculated from price history."""
+ rm = make_risk_manager(volatility_lookback=5)
+
+ # No history yet
+ assert rm.get_volatility("BTC/USDT") is None
+
+ # Feed prices
+ prices = [100, 102, 98, 105, 101]
+ for p in prices:
+ rm.update_price("BTC/USDT", Decimal(str(p)))
+
+ vol = rm.get_volatility("BTC/USDT")
+ assert vol is not None
+ assert vol > 0
+
+
+def test_position_size_with_volatility_scaling():
+ """High volatility should reduce position size."""
+ rm = make_risk_manager(volatility_scale=True, volatility_lookback=5)
+
+ # Feed volatile prices
+ prices = [100, 120, 80, 130, 70]
+ for p in prices:
+ rm.update_price("BTC/USDT", Decimal(str(p)))
+
+ size = rm.calculate_position_size("BTC/USDT", Decimal("10000"))
+ base = Decimal("10000") * Decimal("0.1")
+
+ # High volatility should reduce size below base
+ assert size < base
+
+
+def test_position_size_without_scaling():
+ """Without scaling, position size should be base size regardless of volatility."""
+ rm = make_risk_manager(volatility_scale=False, volatility_lookback=5)
+
+ prices = [100, 120, 80, 130, 70]
+ for p in prices:
+ rm.update_price("BTC/USDT", Decimal(str(p)))
+
+ size = rm.calculate_position_size("BTC/USDT", Decimal("10000"))
+ base = Decimal("10000") * Decimal("0.1")
+
+ assert size == base