"""Risk management for order execution."""

from dataclasses import dataclass
from datetime import datetime, timezone, timedelta
from decimal import Decimal
from collections import deque
import math

from shared.models import Signal, OrderSide, Position


@dataclass
class RiskCheckResult:
    """Outcome of a risk check: whether the order may proceed, and why."""

    allowed: bool
    reason: str


@dataclass
class TrailingStop:
    """Tracks trailing stop for a symbol."""

    symbol: str
    highest_price: Decimal  # high-water mark since the stop was set
    stop_pct: Decimal  # e.g. 5.0 for 5%

    @property
    def stop_price(self) -> Decimal:
        """Price at which the stop fires: ``stop_pct`` percent below the high-water mark."""
        return self.highest_price * (1 - self.stop_pct / 100)

    def update(self, current_price: Decimal) -> None:
        """Raise the high-water mark if ``current_price`` exceeds it (it is never lowered)."""
        if current_price > self.highest_price:
            self.highest_price = current_price

    def is_triggered(self, current_price: Decimal) -> bool:
        """Return True when ``current_price`` is at or below :attr:`stop_price`."""
        return current_price <= self.stop_price


class RiskManager:
    """Evaluates risk before order execution with advanced features.

    Features:
      * drawdown-based position scaling and a hard drawdown halt;
      * a timed pause after a run of consecutive losing trades;
      * per-symbol trailing stops;
      * inverse-volatility position sizing (optional);
      * portfolio exposure, correlated-exposure and historical-VaR limits.

    Price history for volatility/correlation/VaR is fed in through
    :meth:`update_price`; peak balance for drawdown through
    :meth:`update_balance`. Neither is updated automatically by :meth:`check`.
    """

    def __init__(
        self,
        max_position_size: Decimal,
        stop_loss_pct: Decimal,
        daily_loss_limit_pct: Decimal,
        trailing_stop_pct: Decimal = Decimal("0"),
        max_open_positions: int = 10,
        volatility_lookback: int = 20,
        volatility_scale: bool = False,
        max_portfolio_exposure: float = 0.8,
        max_correlated_exposure: float = 0.5,
        correlation_threshold: float = 0.7,
        var_confidence: float = 0.95,
        var_limit_pct: float = 5.0,
        drawdown_reduction_threshold: float = 0.1,  # Start reducing at 10% drawdown
        drawdown_halt_threshold: float = 0.2,  # Halt trading at 20% drawdown
        max_consecutive_losses: int = 5,  # Pause after 5 consecutive losses
        loss_pause_minutes: int = 60,  # Pause for 60 minutes after consecutive losses
    ) -> None:
        """Configure risk limits.

        Args:
            max_position_size: Maximum value of one symbol's position as a
                *fraction* of balance (e.g. ``Decimal("0.1")`` for 10%).
            stop_loss_pct: Stored on the instance; not read by any method in
                this class.
            daily_loss_limit_pct: Daily loss limit in percent of balance.
            trailing_stop_pct: Trailing stop distance in percent; ``0``
                disables trailing stops.
            max_open_positions: Maximum number of symbols held at once.
            volatility_lookback: Number of most recent prices kept per symbol.
            volatility_scale: Enable inverse-volatility position sizing.
            max_portfolio_exposure: Max total position value / balance.
            max_correlated_exposure: Max correlated position value / balance.
            correlation_threshold: ``|corr|`` at or above this counts as correlated.
            var_confidence: Confidence level for historical VaR (e.g. 0.95).
            var_limit_pct: Maximum allowed VaR, in percent of balance.
            drawdown_reduction_threshold: Drawdown fraction where scaling begins.
            drawdown_halt_threshold: Drawdown fraction where trading halts.
            max_consecutive_losses: Losing streak length that triggers a pause.
            loss_pause_minutes: Length of that pause in minutes.
        """
        self.max_position_size = max_position_size
        self.stop_loss_pct = stop_loss_pct
        self.daily_loss_limit_pct = daily_loss_limit_pct
        self.trailing_stop_pct = trailing_stop_pct
        self.max_open_positions = max_open_positions
        self.volatility_lookback = volatility_lookback
        self.volatility_scale = volatility_scale
        self._trailing_stops: dict[str, TrailingStop] = {}
        self._price_history: dict[str, deque[float]] = {}
        self._return_history: dict[str, list[float]] = {}
        self._max_portfolio_exposure = Decimal(str(max_portfolio_exposure))
        self._max_correlated_exposure = Decimal(str(max_correlated_exposure))
        self._correlation_threshold = correlation_threshold
        self._var_confidence = var_confidence
        self._var_limit_pct = Decimal(str(var_limit_pct))
        self._drawdown_reduction_threshold = drawdown_reduction_threshold
        self._drawdown_halt_threshold = drawdown_halt_threshold
        self._max_consecutive_losses = max_consecutive_losses
        self._loss_pause_minutes = loss_pause_minutes
        self._peak_balance: Decimal = Decimal("0")
        self._consecutive_losses: int = 0
        self._paused_until: datetime | None = None

    # ------------------------------------------------------------------
    # Drawdown and loss-streak tracking
    # ------------------------------------------------------------------

    def update_balance(self, current_balance: Decimal) -> None:
        """Track peak balance for drawdown calculation."""
        if current_balance > self._peak_balance:
            self._peak_balance = current_balance

    def get_current_drawdown(self, current_balance: Decimal) -> float:
        """Calculate current drawdown from peak as a fraction (0.0 to 1.0).

        Returns 0.0 before any peak has been recorded, and never goes negative
        when the balance is above the recorded peak.
        """
        if self._peak_balance <= 0:
            return 0.0
        dd = float((self._peak_balance - current_balance) / self._peak_balance)
        return max(dd, 0.0)

    def get_position_scale(self, current_balance: Decimal) -> float:
        """Get position size multiplier based on current drawdown.

        Returns 1.0 (full size) when no drawdown.
        Linearly reduces to 0.25 between reduction threshold and halt threshold.
        Returns 0.0 at or beyond halt threshold.
        """
        dd = self.get_current_drawdown(current_balance)
        if dd >= self._drawdown_halt_threshold:
            return 0.0
        if dd >= self._drawdown_reduction_threshold:
            # Linear interpolation from 1.0 to 0.25. The halt branch above runs
            # first, so the denominator is positive whenever this line executes.
            range_pct = (dd - self._drawdown_reduction_threshold) / (
                self._drawdown_halt_threshold - self._drawdown_reduction_threshold
            )
            return max(1.0 - 0.75 * range_pct, 0.25)
        return 1.0

    def record_trade_result(self, is_win: bool) -> None:
        """Record a trade result for consecutive loss tracking.

        A win resets the streak. Reaching ``max_consecutive_losses`` starts (or
        restarts) a pause of ``loss_pause_minutes`` from now (UTC).
        """
        if is_win:
            self._consecutive_losses = 0
        else:
            self._consecutive_losses += 1
            if self._consecutive_losses >= self._max_consecutive_losses:
                self._paused_until = datetime.now(timezone.utc) + timedelta(
                    minutes=self._loss_pause_minutes
                )

    def is_paused(self) -> bool:
        """Check if trading is paused due to consecutive losses.

        When an existing pause has expired, this call clears it and resets the
        loss streak as a side effect.
        """
        if self._paused_until is None:
            return False
        if datetime.now(timezone.utc) >= self._paused_until:
            self._paused_until = None
            self._consecutive_losses = 0
            return False
        return True

    # ------------------------------------------------------------------
    # Price tracking and trailing stops
    # ------------------------------------------------------------------

    def update_price(self, symbol: str, price: Decimal) -> None:
        """Update price tracking for trailing stops and volatility."""
        # Trailing stop
        if symbol in self._trailing_stops:
            self._trailing_stops[symbol].update(price)

        # Price history for volatility; the deque keeps only the last
        # ``volatility_lookback`` prices.
        if symbol not in self._price_history:
            self._price_history[symbol] = deque(maxlen=self.volatility_lookback)
        self._price_history[symbol].append(float(price))

    def set_trailing_stop(self, symbol: str, entry_price: Decimal) -> None:
        """Set a trailing stop for a new position (no-op when ``trailing_stop_pct`` is 0)."""
        if self.trailing_stop_pct > 0:
            self._trailing_stops[symbol] = TrailingStop(
                symbol=symbol,
                highest_price=entry_price,
                stop_pct=self.trailing_stop_pct,
            )

    def remove_trailing_stop(self, symbol: str) -> None:
        """Remove trailing stop when position is closed."""
        self._trailing_stops.pop(symbol, None)

    # ------------------------------------------------------------------
    # Statistics
    # ------------------------------------------------------------------

    @staticmethod
    def _simple_returns(prices: list[float]) -> list[float]:
        """Step-to-step simple returns, skipping any step whose base price is zero."""
        return [(curr - prev) / prev for prev, curr in zip(prices, prices[1:]) if prev != 0]

    def get_volatility(self, symbol: str) -> float | None:
        """Calculate annualized volatility for a symbol.

        Uses the population standard deviation of simple returns over the
        stored history. Returns None with fewer than two prices or no usable
        returns.
        """
        history = self._price_history.get(symbol)
        if not history or len(history) < 2:
            return None
        returns = self._simple_returns(list(history))
        if not returns:
            return None
        mean = sum(returns) / len(returns)
        var = sum((r - mean) ** 2 for r in returns) / len(returns)
        daily_vol = math.sqrt(var)
        # NOTE(review): sqrt(365) annualization assumes one update_price call
        # per calendar day — confirm against the actual update cadence.
        return daily_vol * math.sqrt(365)  # Annualized

    def calculate_position_size(self, symbol: str, balance: Decimal) -> Decimal:
        """Calculate position size adjusted for volatility.

        Lower volatility -> larger position, higher volatility -> smaller position.
        Base: max_position_size of balance. Scaled by inverse volatility, capped
        at 2x; unscaled when volatility scaling is off or volatility is unknown.
        """
        base_size = balance * self.max_position_size
        if not self.volatility_scale:
            return base_size

        vol = self.get_volatility(symbol)
        if vol is None or vol == 0:
            return base_size

        # Target volatility of 20% annualized
        target_vol = 0.20
        scale = min(target_vol / vol, 2.0)  # Cap at 2x
        return base_size * Decimal(str(scale))

    def calculate_correlation(self, symbol_a: str, symbol_b: str) -> float | None:
        """Calculate Pearson correlation between two symbols' returns.

        Both histories are aligned on their most recent observations. Returns
        None with fewer than 5 prices or 3 returns per symbol, or when either
        return series is constant.
        """
        hist_a = self._price_history.get(symbol_a)
        hist_b = self._price_history.get(symbol_b)
        if not hist_a or not hist_b or len(hist_a) < 5 or len(hist_b) < 5:
            return None

        prices_a = list(hist_a)
        prices_b = list(hist_b)
        min_len = min(len(prices_a), len(prices_b))
        returns_a = self._simple_returns(prices_a[-min_len:])
        returns_b = self._simple_returns(prices_b[-min_len:])

        if len(returns_a) < 3 or len(returns_b) < 3:
            return None

        # Zero-price skips can leave the series unequal; re-align on the tail.
        min_len = min(len(returns_a), len(returns_b))
        returns_a = returns_a[-min_len:]
        returns_b = returns_b[-min_len:]

        mean_a = sum(returns_a) / len(returns_a)
        mean_b = sum(returns_b) / len(returns_b)
        cov = sum((a - mean_a) * (b - mean_b) for a, b in zip(returns_a, returns_b)) / len(
            returns_a
        )
        std_a = math.sqrt(sum((a - mean_a) ** 2 for a in returns_a) / len(returns_a))
        std_b = math.sqrt(sum((b - mean_b) ** 2 for b in returns_b) / len(returns_b))

        if std_a == 0 or std_b == 0:
            return None
        return cov / (std_a * std_b)

    def calculate_portfolio_var(self, positions: dict[str, Position], balance: Decimal) -> float:
        """Calculate portfolio VaR using historical simulation.

        Each long position with at least 5 stored prices contributes its simple
        return series, weighted by ``quantity * current_price / balance``. The
        series are aligned on their most recent observations and summed into a
        portfolio return series. VaR is the loss at the ``(1 - var_confidence)``
        quantile of that series.

        Returns VaR as a percentage of balance (e.g., 3.5 for 3.5%). Returns 0.0
        when there is nothing to measure, and also when the quantile return is
        itself a gain (no loss at that confidence level).
        """
        if not positions or balance <= 0:
            return 0.0

        # Collect returns for all positioned symbols
        all_returns: list[list[float]] = []
        weights: list[float] = []
        for symbol, pos in positions.items():
            if pos.quantity <= 0:
                continue
            hist = self._price_history.get(symbol)
            if not hist or len(hist) < 5:
                continue
            returns = self._simple_returns(list(hist))
            if returns:
                all_returns.append(returns)
                weights.append(float(pos.quantity * pos.current_price / balance))

        if not all_returns:
            return 0.0

        # Portfolio returns (weighted sum) over the common, most recent window.
        # Every series is non-empty, so the window has at least one step.
        window = min(len(r) for r in all_returns)
        tails = [r[-window:] for r in all_returns]
        portfolio_returns = [
            sum(w * tail[i] for w, tail in zip(weights, tails)) for i in range(window)
        ]

        # Historical VaR: sort returns, take the (1-confidence) percentile
        sorted_returns = sorted(portfolio_returns)
        index = int((1 - self._var_confidence) * len(sorted_returns))
        index = max(0, min(index, len(sorted_returns) - 1))
        var_return = sorted_returns[index]
        # VaR is a loss: a negative quantile return is a loss of -var_return,
        # while a non-negative one means no loss (abs() would count gains as losses).
        return max(-var_return, 0.0) * 100  # As percentage

    # ------------------------------------------------------------------
    # Checks
    # ------------------------------------------------------------------

    def check_portfolio_exposure(
        self,
        positions: dict[str, Position],
        balance: Decimal,
        additional_exposure: Decimal = Decimal("0"),
    ) -> RiskCheckResult:
        """Check total portfolio exposure.

        Args:
            positions: Current positions; only those with quantity > 0 count.
            balance: Denominator for the exposure ratio; a non-positive
                balance always passes.
            additional_exposure: Value of a prospective order to include in
                the total (defaults to 0, i.e. current book only).
        """
        if balance <= 0:
            return RiskCheckResult(allowed=True, reason="OK")

        total_exposure = (
            sum(
                (pos.quantity * pos.current_price for pos in positions.values() if pos.quantity > 0),
                Decimal("0"),
            )
            + additional_exposure
        )
        exposure_ratio = total_exposure / balance

        if exposure_ratio > self._max_portfolio_exposure:
            return RiskCheckResult(
                allowed=False,
                reason=f"Portfolio exposure {float(exposure_ratio):.1%} exceeds max {float(self._max_portfolio_exposure):.1%}",
            )
        return RiskCheckResult(allowed=True, reason="OK")

    def check_correlation_risk(
        self, signal: Signal, positions: dict[str, Position], balance: Decimal
    ) -> RiskCheckResult:
        """Check if adding this position creates too much correlated exposure.

        Sums the order value with the value of every other held symbol whose
        return correlation with ``signal.symbol`` has magnitude at or above
        ``correlation_threshold``. Non-BUY signals always pass.
        """
        if signal.side != OrderSide.BUY or balance <= 0:
            return RiskCheckResult(allowed=True, reason="OK")

        correlated_value = signal.price * signal.quantity
        for symbol, pos in positions.items():
            if pos.quantity <= 0 or symbol == signal.symbol:
                continue
            corr = self.calculate_correlation(signal.symbol, symbol)
            if corr is not None and abs(corr) >= self._correlation_threshold:
                correlated_value += pos.quantity * pos.current_price

        if correlated_value / balance > self._max_correlated_exposure:
            return RiskCheckResult(
                allowed=False,
                reason=f"Correlated exposure would exceed {float(self._max_correlated_exposure):.1%}",
            )
        return RiskCheckResult(allowed=True, reason="OK")

    def check(
        self,
        signal: Signal,
        balance: Decimal,
        positions: dict[str, Position],
        daily_pnl: Decimal,
    ) -> RiskCheckResult:
        """Run risk checks against a signal and current portfolio state.

        Account-level gates (loss-streak pause, drawdown halt, daily loss
        limit) apply to every order. All remaining gates apply only to BUY
        orders, because those are the orders that add risk. A SELL must not be
        blocked by the very exposure/VaR limit it would help bring back down.
        The first failing check's result is returned.
        """
        # Check if paused due to consecutive losses
        if self.is_paused():
            return RiskCheckResult(
                allowed=False,
                reason=f"Trading paused until {self._paused_until.isoformat()} after {self._max_consecutive_losses} consecutive losses",
            )

        # Check drawdown halt
        # NOTE(review): this and the daily-loss gate also block exits — confirm intended.
        dd = self.get_current_drawdown(balance)
        if dd >= self._drawdown_halt_threshold:
            return RiskCheckResult(
                allowed=False,
                reason=f"Trading halted: drawdown {dd:.1%} exceeds {self._drawdown_halt_threshold:.1%}",
            )

        # Check daily loss limit
        if balance > 0 and (daily_pnl / balance) * 100 < -self.daily_loss_limit_pct:
            return RiskCheckResult(allowed=False, reason="Daily loss limit exceeded")

        if signal.side != OrderSide.BUY:
            return RiskCheckResult(allowed=True, reason="OK")

        # Check trailing stop
        trailing = self._trailing_stops.get(signal.symbol)
        if trailing is not None and trailing.is_triggered(signal.price):
            return RiskCheckResult(
                allowed=False,
                reason=f"Trailing stop triggered at {trailing.stop_price}",
            )

        order_cost = signal.price * signal.quantity

        # Check sufficient balance
        if order_cost > balance:
            return RiskCheckResult(allowed=False, reason="Insufficient balance")

        # Check max open positions; adding to an already-open symbol does not
        # open a new position, so it is not limited by this count.
        position = positions.get(signal.symbol)
        already_open = position is not None and position.quantity > 0
        open_count = sum(1 for p in positions.values() if p.quantity > 0)
        if not already_open and open_count >= self.max_open_positions:
            return RiskCheckResult(allowed=False, reason="Max open positions reached")

        # Check position size limit (existing holding plus this order)
        current_position_value = Decimal(0)
        if position is not None:
            current_position_value = position.quantity * position.current_price
        if (
            balance > 0
            and (current_position_value + order_cost) / balance > self.max_position_size
        ):
            return RiskCheckResult(allowed=False, reason="Position size exceeded")

        # Portfolio-level checks (including the prospective order)
        exposure_check = self.check_portfolio_exposure(
            positions, balance, additional_exposure=order_cost
        )
        if not exposure_check.allowed:
            return exposure_check

        corr_check = self.check_correlation_risk(signal, positions, balance)
        if not corr_check.allowed:
            return corr_check

        # VaR check — evaluated on the current book, not including this order.
        if positions:
            var = self.calculate_portfolio_var(positions, balance)
            if var > float(self._var_limit_pct):
                return RiskCheckResult(
                    allowed=False,
                    reason=f"Portfolio VaR {var:.1f}% exceeds limit {float(self._var_limit_pct):.1f}%",
                )

        return RiskCheckResult(allowed=True, reason="OK")