diff options
Diffstat (limited to 'services/order-executor')
| -rw-r--r-- | services/order-executor/pyproject.toml | 2 | ||||
| -rw-r--r-- | services/order-executor/src/order_executor/executor.py | 14 | ||||
| -rw-r--r-- | services/order-executor/src/order_executor/main.py | 80 | ||||
| -rw-r--r-- | services/order-executor/src/order_executor/risk_manager.py | 251 | ||||
| -rw-r--r-- | services/order-executor/tests/test_executor.py | 20 | ||||
| -rw-r--r-- | services/order-executor/tests/test_risk_manager.py | 198 |
6 files changed, 487 insertions, 78 deletions
diff --git a/services/order-executor/pyproject.toml b/services/order-executor/pyproject.toml index eed4fef..7bb1030 100644 --- a/services/order-executor/pyproject.toml +++ b/services/order-executor/pyproject.toml @@ -3,7 +3,7 @@ name = "order-executor" version = "0.1.0" description = "Order execution service with risk management" requires-python = ">=3.12" -dependencies = ["ccxt>=4.0", "trading-shared"] +dependencies = ["trading-shared"] [project.optional-dependencies] dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] diff --git a/services/order-executor/src/order_executor/executor.py b/services/order-executor/src/order_executor/executor.py index 80f441d..a71e762 100644 --- a/services/order-executor/src/order_executor/executor.py +++ b/services/order-executor/src/order_executor/executor.py @@ -37,12 +37,8 @@ class OrderExecutor: async def execute(self, signal: Signal) -> Optional[Order]: """Run risk checks and place an order for the given signal.""" - # Fetch current balance from exchange - balance_data = await self.exchange.fetch_balance() - # Use USDT (or quote currency) free balance as available capital - free_balances = balance_data.get("free", {}) - quote_currency = signal.symbol.split("/")[-1] if "/" in signal.symbol else "USDT" - balance = Decimal(str(free_balances.get(quote_currency, 0))) + # Fetch buying power from Alpaca + balance = await self.exchange.get_buying_power() # Fetch current positions positions = {} @@ -84,11 +80,11 @@ class OrderExecutor: ) else: try: - await self.exchange.create_order( + await self.exchange.submit_order( symbol=signal.symbol, - type="market", + qty=float(signal.quantity), side=signal.side.value.lower(), - amount=float(signal.quantity), + type="market", ) order.status = OrderStatus.FILLED order.filled_at = datetime.now(timezone.utc) diff --git a/services/order-executor/src/order_executor/main.py b/services/order-executor/src/order_executor/main.py index 3fe4c12..51ab286 100644 --- a/services/order-executor/src/order_executor/main.py +++ b/services/order-executor/src/order_executor/main.py @@ -3,11 +3,11 @@ import asyncio from decimal import Decimal +from shared.alpaca import AlpacaClient from shared.broker import RedisBroker from shared.db import Database from shared.events import Event, EventType from shared.healthcheck import HealthCheckServer -from shared.exchange import create_exchange from shared.logging import setup_logging from shared.metrics import ServiceMetrics from shared.notifier import TelegramNotifier @@ -16,9 +16,7 @@ from order_executor.config import ExecutorConfig from order_executor.executor import OrderExecutor from order_executor.risk_manager import RiskManager -# Health check port: base (HEALTH_PORT, default 8080) + offset -# data-collector: +0 (8080), strategy-engine: +1 (8081) -# order-executor: +2 (8082), portfolio-manager: +3 (8083) +# Health check port: base + 2 HEALTH_PORT_OFFSET = 2 @@ -26,21 +24,21 @@ async def run() -> None: config = ExecutorConfig() log = setup_logging("order-executor", config.log_level, config.log_format) metrics = ServiceMetrics("order_executor") + notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id + bot_token=config.telegram_bot_token, + chat_id=config.telegram_chat_id, ) db = Database(config.database_url) await db.connect() - await db.init_tables() broker = RedisBroker(config.redis_url) - exchange = create_exchange( - exchange_id=config.exchange_id, - api_key=config.binance_api_key, - api_secret=config.binance_api_secret, - sandbox=config.exchange_sandbox, + alpaca = AlpacaClient( + api_key=config.alpaca_api_key, + api_secret=config.alpaca_api_secret, + paper=config.alpaca_paper, ) risk_manager = RiskManager( @@ -51,10 +49,19 @@ async def run() -> None: max_open_positions=config.risk_max_open_positions, volatility_lookback=config.risk_volatility_lookback, volatility_scale=config.risk_volatility_scale, + max_portfolio_exposure=config.risk_max_portfolio_exposure, + max_correlated_exposure=config.risk_max_correlated_exposure, + correlation_threshold=config.risk_correlation_threshold, + var_confidence=config.risk_var_confidence, + var_limit_pct=config.risk_var_limit_pct, + drawdown_reduction_threshold=config.risk_drawdown_reduction_threshold, + drawdown_halt_threshold=config.risk_drawdown_halt_threshold, + max_consecutive_losses=config.risk_max_consecutive_losses, + loss_pause_minutes=config.risk_loss_pause_minutes, ) executor = OrderExecutor( - exchange=exchange, + exchange=alpaca, risk_manager=risk_manager, broker=broker, db=db, @@ -62,41 +69,34 @@ async def run() -> None: dry_run=config.dry_run, ) - GROUP = "order-executor" - CONSUMER = "executor-1" - stream = "signals" - health = HealthCheckServer( "order-executor", port=config.health_port + HEALTH_PORT_OFFSET, auth_token=config.metrics_auth_token, ) - health.register_check("redis", broker.ping) await health.start() metrics.service_up.labels(service="order-executor").set(1) - log.info("service_started", stream=stream, dry_run=config.dry_run) + GROUP = "order-executor" + CONSUMER = "executor-1" + stream = "signals" await broker.ensure_group(stream, GROUP) - # Process pending messages first (from previous crash) - pending = await broker.read_pending(stream, GROUP, CONSUMER) - for msg_id, msg in pending: - try: - event = Event.from_dict(msg) - if event.type == EventType.SIGNAL: - signal = event.data - log.info( - "processing_pending_signal", signal_id=str(signal.id), symbol=signal.symbol - ) - await executor.execute(signal) - metrics.events_processed.labels(service="order-executor", event_type="signal").inc() - await broker.ack(stream, GROUP, msg_id) - except Exception as exc: - log.error("pending_process_failed", error=str(exc), msg_id=msg_id) - metrics.errors_total.labels(service="order-executor", error_type="processing").inc() + log.info("started", stream=stream, dry_run=config.dry_run) try: + # Process pending messages first + pending = await broker.read_pending(stream, GROUP, CONSUMER) + for msg_id, msg in pending: + try: + event = Event.from_dict(msg) + if event.type == EventType.SIGNAL: + await executor.execute(event.data) + await broker.ack(stream, GROUP, msg_id) + except Exception as exc: + log.error("pending_failed", error=str(exc), msg_id=msg_id) + while True: messages = await broker.read_group(stream, GROUP, CONSUMER, count=10, block=5000) for msg_id, msg in messages: @@ -104,29 +104,23 @@ async def run() -> None: event = Event.from_dict(msg) if event.type == EventType.SIGNAL: signal = event.data - log.info( - "processing_signal", signal_id=str(signal.id), symbol=signal.symbol - ) + log.info("processing_signal", signal_id=signal.id, symbol=signal.symbol) await executor.execute(signal) metrics.events_processed.labels( service="order-executor", event_type="signal" ).inc() await broker.ack(stream, GROUP, msg_id) except Exception as exc: - log.error("message_processing_failed", error=str(exc), msg_id=msg_id) + log.error("process_failed", error=str(exc)) metrics.errors_total.labels( service="order-executor", error_type="processing" ).inc() - except Exception as exc: - log.error("fatal_error", error=str(exc)) - await notifier.send_error(str(exc), "order-executor") - raise finally: metrics.service_up.labels(service="order-executor").set(0) await notifier.close() await broker.close() await db.close() - await exchange.close() + await alpaca.close() def main() -> None: diff --git a/services/order-executor/src/order_executor/risk_manager.py b/services/order-executor/src/order_executor/risk_manager.py index c3578a7..5a05746 100644 --- a/services/order-executor/src/order_executor/risk_manager.py +++ b/services/order-executor/src/order_executor/risk_manager.py @@ -1,6 +1,7 @@ """Risk management for order execution.""" from dataclasses import dataclass +from datetime import datetime, timezone, timedelta from decimal import Decimal from collections import deque import math @@ -46,6 +47,15 @@ class RiskManager: max_open_positions: int = 10, volatility_lookback: int = 20, volatility_scale: bool = False, + max_portfolio_exposure: float = 0.8, + max_correlated_exposure: float = 0.5, + correlation_threshold: float = 0.7, + var_confidence: float = 0.95, + var_limit_pct: float = 5.0, + drawdown_reduction_threshold: float = 0.1, # Start reducing at 10% drawdown + drawdown_halt_threshold: float = 0.2, # Halt trading at 20% drawdown + max_consecutive_losses: int = 5, # Pause after 5 consecutive losses + loss_pause_minutes: int = 60, # Pause for 60 minutes after consecutive losses ) -> None: self.max_position_size = max_position_size self.stop_loss_pct = stop_loss_pct @@ -57,6 +67,75 @@ class RiskManager: self._trailing_stops: dict[str, TrailingStop] = {} self._price_history: dict[str, deque[float]] = {} + self._return_history: dict[str, list[float]] = {} + self._max_portfolio_exposure = Decimal(str(max_portfolio_exposure)) + self._max_correlated_exposure = Decimal(str(max_correlated_exposure)) + self._correlation_threshold = correlation_threshold + self._var_confidence = var_confidence + self._var_limit_pct = Decimal(str(var_limit_pct)) + + self._drawdown_reduction_threshold = drawdown_reduction_threshold + self._drawdown_halt_threshold = drawdown_halt_threshold + self._max_consecutive_losses = max_consecutive_losses + self._loss_pause_minutes = loss_pause_minutes + + self._peak_balance: Decimal = Decimal("0") + self._consecutive_losses: int = 0 + self._paused_until: datetime | None = None + + def update_balance(self, current_balance: Decimal) -> None: + """Track peak balance for drawdown calculation.""" + if current_balance > self._peak_balance: + self._peak_balance = current_balance + + def get_current_drawdown(self, current_balance: Decimal) -> float: + """Calculate current drawdown from peak as a fraction (0.0 to 1.0).""" + if self._peak_balance <= 0: + return 0.0 + dd = float((self._peak_balance - current_balance) / self._peak_balance) + return max(dd, 0.0) + + def get_position_scale(self, current_balance: Decimal) -> float: + """Get position size multiplier based on current drawdown. + + Returns 1.0 (full size) when no drawdown. + Linearly reduces to 0.25 between reduction threshold and halt threshold. + Returns 0.0 at or beyond halt threshold. + """ + dd = self.get_current_drawdown(current_balance) + + if dd >= self._drawdown_halt_threshold: + return 0.0 + + if dd >= self._drawdown_reduction_threshold: + # Linear interpolation from 1.0 to 0.25 + range_pct = (dd - self._drawdown_reduction_threshold) / ( + self._drawdown_halt_threshold - self._drawdown_reduction_threshold + ) + return max(1.0 - 0.75 * range_pct, 0.25) + + return 1.0 + + def record_trade_result(self, is_win: bool) -> None: + """Record a trade result for consecutive loss tracking.""" + if is_win: + self._consecutive_losses = 0 + else: + self._consecutive_losses += 1 + if self._consecutive_losses >= self._max_consecutive_losses: + self._paused_until = datetime.now(timezone.utc) + timedelta( + minutes=self._loss_pause_minutes + ) + + def is_paused(self) -> bool: + """Check if trading is paused due to consecutive losses.""" + if self._paused_until is None: + return False + if datetime.now(timezone.utc) >= self._paused_until: + self._paused_until = None + self._consecutive_losses = 0 + return False + return True def update_price(self, symbol: str, price: Decimal) -> None: """Update price tracking for trailing stops and volatility.""" @@ -120,6 +199,145 @@ class RiskManager: scale = min(target_vol / vol, 2.0) # Cap at 2x return base_size * Decimal(str(scale)) + def calculate_correlation(self, symbol_a: str, symbol_b: str) -> float | None: + """Calculate Pearson correlation between two symbols' returns.""" + hist_a = self._price_history.get(symbol_a) + hist_b = self._price_history.get(symbol_b) + if not hist_a or not hist_b or len(hist_a) < 5 or len(hist_b) < 5: + return None + + prices_a = list(hist_a) + prices_b = list(hist_b) + min_len = min(len(prices_a), len(prices_b)) + prices_a = prices_a[-min_len:] + prices_b = prices_b[-min_len:] + + returns_a = [ + (prices_a[i] - prices_a[i - 1]) / prices_a[i - 1] + for i in range(1, len(prices_a)) + if prices_a[i - 1] != 0 + ] + returns_b = [ + (prices_b[i] - prices_b[i - 1]) / prices_b[i - 1] + for i in range(1, len(prices_b)) + if prices_b[i - 1] != 0 + ] + + if len(returns_a) < 3 or len(returns_b) < 3: + return None + + min_len = min(len(returns_a), len(returns_b)) + returns_a = returns_a[-min_len:] + returns_b = returns_b[-min_len:] + + mean_a = sum(returns_a) / len(returns_a) + mean_b = sum(returns_b) / len(returns_b) + + cov = sum((a - mean_a) * (b - mean_b) for a, b in zip(returns_a, returns_b)) / len( + returns_a + ) + std_a = math.sqrt(sum((a - mean_a) ** 2 for a in returns_a) / len(returns_a)) + std_b = math.sqrt(sum((b - mean_b) ** 2 for b in returns_b) / len(returns_b)) + + if std_a == 0 or std_b == 0: + return None + + return cov / (std_a * std_b) + + def calculate_portfolio_var(self, positions: dict[str, Position], balance: Decimal) -> float: + """Calculate portfolio VaR using historical simulation. + + Returns VaR as a percentage of balance (e.g., 3.5 for 3.5%). + """ + if not positions or balance <= 0: + return 0.0 + + # Collect returns for all positioned symbols + all_returns: list[list[float]] = [] + weights: list[float] = [] + + for symbol, pos in positions.items(): + if pos.quantity <= 0: + continue + hist = self._price_history.get(symbol) + if not hist or len(hist) < 5: + continue + prices = list(hist) + returns = [ + (prices[i] - prices[i - 1]) / prices[i - 1] + for i in range(1, len(prices)) + if prices[i - 1] != 0 + ] + if returns: + all_returns.append(returns) + weight = float(pos.quantity * pos.current_price / balance) + weights.append(weight) + + if not all_returns: + return 0.0 + + # Portfolio returns (weighted sum) + min_len = min(len(r) for r in all_returns) + portfolio_returns = [] + for i in range(min_len): + pr = sum(w * r[-(min_len - i)] for w, r in zip(weights, all_returns) if len(r) > i) + portfolio_returns.append(pr) + + if not portfolio_returns: + return 0.0 + + # Historical VaR: sort returns, take the (1-confidence) percentile + sorted_returns = sorted(portfolio_returns) + index = int((1 - self._var_confidence) * len(sorted_returns)) + index = max(0, min(index, len(sorted_returns) - 1)) + var_return = sorted_returns[index] + + return abs(var_return) * 100 # As percentage + + def check_portfolio_exposure( + self, positions: dict[str, Position], balance: Decimal + ) -> RiskCheckResult: + """Check total portfolio exposure.""" + if balance <= 0: + return RiskCheckResult(allowed=True, reason="OK") + + total_exposure = sum( + pos.quantity * pos.current_price for pos in positions.values() if pos.quantity > 0 + ) + + exposure_ratio = total_exposure / balance + if exposure_ratio > self._max_portfolio_exposure: + return RiskCheckResult( + allowed=False, + reason=f"Portfolio exposure {float(exposure_ratio):.1%} exceeds max {float(self._max_portfolio_exposure):.1%}", + ) + + return RiskCheckResult(allowed=True, reason="OK") + + def check_correlation_risk( + self, signal: Signal, positions: dict[str, Position], balance: Decimal + ) -> RiskCheckResult: + """Check if adding this position creates too much correlated exposure.""" + if signal.side != OrderSide.BUY or balance <= 0: + return RiskCheckResult(allowed=True, reason="OK") + + correlated_value = signal.price * signal.quantity + + for symbol, pos in positions.items(): + if pos.quantity <= 0 or symbol == signal.symbol: + continue + corr = self.calculate_correlation(signal.symbol, symbol) + if corr is not None and abs(corr) >= self._correlation_threshold: + correlated_value += pos.quantity * pos.current_price + + if correlated_value / balance > self._max_correlated_exposure: + return RiskCheckResult( + allowed=False, + reason=f"Correlated exposure would exceed {float(self._max_correlated_exposure):.1%}", + ) + + return RiskCheckResult(allowed=True, reason="OK") + def check( self, signal: Signal, @@ -128,6 +346,21 @@ class RiskManager: daily_pnl: Decimal, ) -> RiskCheckResult: """Run risk checks against a signal and current portfolio state.""" + # Check if paused due to consecutive losses + if self.is_paused(): + return RiskCheckResult( + allowed=False, + reason=f"Trading paused until {self._paused_until.isoformat()} after {self._max_consecutive_losses} consecutive losses", + ) + + # Check drawdown halt + dd = self.get_current_drawdown(balance) + if dd >= self._drawdown_halt_threshold: + return RiskCheckResult( + allowed=False, + reason=f"Trading halted: drawdown {dd:.1%} exceeds {self._drawdown_halt_threshold:.1%}", + ) + # Check daily loss limit if balance > 0 and (daily_pnl / balance) * 100 < -self.daily_loss_limit_pct: return RiskCheckResult(allowed=False, reason="Daily loss limit exceeded") @@ -165,4 +398,22 @@ class RiskManager: ): return RiskCheckResult(allowed=False, reason="Position size exceeded") + # Portfolio-level checks + exposure_check = self.check_portfolio_exposure(positions, balance) + if not exposure_check.allowed: + return exposure_check + + corr_check = self.check_correlation_risk(signal, positions, balance) + if not corr_check.allowed: + return corr_check + + # VaR check + if positions: + var = self.calculate_portfolio_var(positions, balance) + if var > float(self._var_limit_pct): + return RiskCheckResult( + allowed=False, + reason=f"Portfolio VaR {var:.1f}% exceeds limit {float(self._var_limit_pct):.1f}%", + ) + return RiskCheckResult(allowed=True, reason="OK") diff --git a/services/order-executor/tests/test_executor.py b/services/order-executor/tests/test_executor.py index e64b6c0..dd823d7 100644 --- a/services/order-executor/tests/test_executor.py +++ b/services/order-executor/tests/test_executor.py @@ -13,7 +13,7 @@ from order_executor.risk_manager import RiskCheckResult, RiskManager def make_signal(side: OrderSide = OrderSide.BUY, price: str = "100", quantity: str = "1") -> Signal: return Signal( strategy="test", - symbol="BTC/USDT", + symbol="AAPL", side=side, price=Decimal(price), quantity=Decimal(quantity), @@ -21,10 +21,10 @@ def make_signal(side: OrderSide = OrderSide.BUY, price: str = "100", quantity: s ) -def make_mock_exchange(free_usdt: float = 10000.0) -> AsyncMock: +def make_mock_exchange(buying_power: str = "10000") -> AsyncMock: exchange = AsyncMock() - exchange.fetch_balance.return_value = {"free": {"USDT": free_usdt}} - exchange.create_order = AsyncMock(return_value={"id": "exchange-order-123"}) + exchange.get_buying_power = AsyncMock(return_value=Decimal(buying_power)) + exchange.submit_order = AsyncMock(return_value={"id": "alpaca-order-123"}) return exchange @@ -48,7 +48,7 @@ def make_mock_db() -> AsyncMock: @pytest.mark.asyncio async def test_executor_places_order_when_risk_passes(): - """When risk check passes, create_order is called and order status is FILLED.""" + """When risk check passes, submit_order is called and order status is FILLED.""" exchange = make_mock_exchange() risk_manager = make_mock_risk_manager(allowed=True) broker = make_mock_broker() @@ -68,14 +68,14 @@ async def test_executor_places_order_when_risk_passes(): assert order is not None assert order.status == OrderStatus.FILLED - exchange.create_order.assert_called_once() + exchange.submit_order.assert_called_once() db.insert_order.assert_called_once_with(order) broker.publish.assert_called_once() @pytest.mark.asyncio async def test_executor_rejects_when_risk_fails(): - """When risk check fails, create_order is not called and None is returned.""" + """When risk check fails, submit_order is not called and None is returned.""" exchange = make_mock_exchange() risk_manager = make_mock_risk_manager(allowed=False, reason="Position size exceeded") broker = make_mock_broker() @@ -94,14 +94,14 @@ async def test_executor_rejects_when_risk_fails(): order = await executor.execute(signal) assert order is None - exchange.create_order.assert_not_called() + exchange.submit_order.assert_not_called() db.insert_order.assert_not_called() broker.publish.assert_not_called() @pytest.mark.asyncio async def test_executor_dry_run_does_not_call_exchange(): - """In dry-run mode, risk passes, order is FILLED, but exchange.create_order is NOT called.""" + """In dry-run mode, risk passes, order is FILLED, but exchange.submit_order is NOT called.""" exchange = make_mock_exchange() risk_manager = make_mock_risk_manager(allowed=True) broker = make_mock_broker() @@ -121,6 +121,6 @@ async def test_executor_dry_run_does_not_call_exchange(): assert order is not None assert order.status == OrderStatus.FILLED - exchange.create_order.assert_not_called() + exchange.submit_order.assert_not_called() db.insert_order.assert_called_once_with(order) broker.publish.assert_called_once() diff --git a/services/order-executor/tests/test_risk_manager.py b/services/order-executor/tests/test_risk_manager.py index efabe73..3d5175b 100644 --- a/services/order-executor/tests/test_risk_manager.py +++ b/services/order-executor/tests/test_risk_manager.py @@ -7,7 +7,7 @@ from shared.models import OrderSide, Position, Signal from order_executor.risk_manager import RiskManager -def make_signal(side: OrderSide, price: str, quantity: str, symbol: str = "BTC/USDT") -> Signal: +def make_signal(side: OrderSide, price: str, quantity: str, symbol: str = "AAPL") -> Signal: return Signal( strategy="test", symbol=symbol, @@ -93,7 +93,7 @@ def test_risk_check_rejects_insufficient_balance(): def test_trailing_stop_set_and_trigger(): """Trailing stop should trigger when price drops below stop level.""" rm = make_risk_manager(trailing_stop_pct="5") - rm.set_trailing_stop("BTC/USDT", Decimal("100")) + rm.set_trailing_stop("AAPL", Decimal("100")) signal = make_signal(side=OrderSide.BUY, price="94", quantity="0.01") result = rm.check(signal, balance=Decimal("10000"), positions={}, daily_pnl=Decimal("0")) @@ -104,10 +104,10 @@ def test_trailing_stop_set_and_trigger(): def test_trailing_stop_updates_highest_price(): """Trailing stop should track the highest price seen.""" rm = make_risk_manager(trailing_stop_pct="5") - rm.set_trailing_stop("BTC/USDT", Decimal("100")) + rm.set_trailing_stop("AAPL", Decimal("100")) # Price rises to 120 => stop at 114 - rm.update_price("BTC/USDT", Decimal("120")) + rm.update_price("AAPL", Decimal("120")) # Price at 115 is above stop (114), should be allowed signal = make_signal(side=OrderSide.BUY, price="115", quantity="0.01") @@ -124,7 +124,7 @@ def test_trailing_stop_updates_highest_price(): def test_trailing_stop_not_triggered_above_stop(): """Trailing stop should not trigger when price is above stop level.""" rm = make_risk_manager(trailing_stop_pct="5") - rm.set_trailing_stop("BTC/USDT", Decimal("100")) + rm.set_trailing_stop("AAPL", Decimal("100")) # Price at 96 is above stop (95), should be allowed signal = make_signal(side=OrderSide.BUY, price="96", quantity="0.01") @@ -140,11 +140,11 @@ def test_max_open_positions_check(): rm = make_risk_manager(max_open_positions=2) positions = { - "BTC/USDT": make_position("BTC/USDT", "1", "100", "100"), - "ETH/USDT": make_position("ETH/USDT", "10", "50", "50"), + "AAPL": make_position("AAPL", "1", "100", "100"), + "MSFT": make_position("MSFT", "10", "50", "50"), } - signal = make_signal(side=OrderSide.BUY, price="10", quantity="1", symbol="SOL/USDT") + signal = make_signal(side=OrderSide.BUY, price="10", quantity="1", symbol="TSLA") result = rm.check(signal, balance=Decimal("10000"), positions=positions, daily_pnl=Decimal("0")) assert result.allowed is False assert result.reason == "Max open positions reached" @@ -158,14 +158,14 @@ def test_volatility_calculation(): rm = make_risk_manager(volatility_lookback=5) # No history yet - assert rm.get_volatility("BTC/USDT") is None + assert rm.get_volatility("AAPL") is None # Feed prices prices = [100, 102, 98, 105, 101] for p in prices: - rm.update_price("BTC/USDT", Decimal(str(p))) + rm.update_price("AAPL", Decimal(str(p))) - vol = rm.get_volatility("BTC/USDT") + vol = rm.get_volatility("AAPL") assert vol is not None assert vol > 0 @@ -177,9 +177,9 @@ def test_position_size_with_volatility_scaling(): # Feed volatile prices prices = [100, 120, 80, 130, 70] for p in prices: - rm.update_price("BTC/USDT", Decimal(str(p))) + rm.update_price("AAPL", Decimal(str(p))) - size = rm.calculate_position_size("BTC/USDT", Decimal("10000")) + size = rm.calculate_position_size("AAPL", Decimal("10000")) base = Decimal("10000") * Decimal("0.1") # High volatility should reduce size below base @@ -192,9 +192,177 @@ def test_position_size_without_scaling(): prices = [100, 120, 80, 130, 70] for p in prices: - rm.update_price("BTC/USDT", Decimal(str(p))) + rm.update_price("AAPL", Decimal(str(p))) - size = rm.calculate_position_size("BTC/USDT", Decimal("10000")) + size = rm.calculate_position_size("AAPL", Decimal("10000")) base = Decimal("10000") * Decimal("0.1") assert size == base + + +# --- Portfolio exposure tests --- + + +def test_portfolio_exposure_check_passes(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + max_portfolio_exposure=0.8, + ) + positions = { + "AAPL": Position( + symbol="AAPL", + quantity=Decimal("0.01"), + avg_entry_price=Decimal("50000"), + current_price=Decimal("50000"), + ) + } + result = rm.check_portfolio_exposure(positions, Decimal("10000")) + assert result.allowed # 500/10000 = 5% < 80% + + +def test_portfolio_exposure_check_rejects(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + max_portfolio_exposure=0.3, + ) + positions = { + "AAPL": Position( + symbol="AAPL", + quantity=Decimal("1"), + avg_entry_price=Decimal("50000"), + current_price=Decimal("50000"), + ) + } + result = rm.check_portfolio_exposure(positions, Decimal("10000")) + assert not result.allowed # 50000/10000 = 500% > 30% + + +def test_correlation_calculation(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + ) + # Feed identical price histories — correlation should be ~1.0 + for i in range(20): + rm.update_price("A", Decimal(str(100 + i))) + rm.update_price("B", Decimal(str(100 + i))) + corr = rm.calculate_correlation("A", "B") + assert corr is not None + assert corr > 0.9 + + +def test_var_calculation(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + ) + for i in range(30): + rm.update_price("AAPL", Decimal(str(100 + (i % 5) - 2))) + positions = { + "AAPL": Position( + symbol="AAPL", + quantity=Decimal("1"), + avg_entry_price=Decimal("100"), + current_price=Decimal("100"), + ) + } + var = rm.calculate_portfolio_var(positions, Decimal("10000")) + assert var >= 0 # Non-negative + + +# --- Drawdown position scaling tests --- + + +def test_drawdown_position_scale_full(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + drawdown_reduction_threshold=0.1, + drawdown_halt_threshold=0.2, + ) + rm.update_balance(Decimal("10000")) + scale = rm.get_position_scale(Decimal("10000")) + assert scale == 1.0 # No drawdown + + +def test_drawdown_position_scale_reduced(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + drawdown_reduction_threshold=0.1, + drawdown_halt_threshold=0.2, + ) + rm.update_balance(Decimal("10000")) + scale = rm.get_position_scale(Decimal("8500")) # 15% drawdown (between 10% and 20%) + assert 0.25 < scale < 1.0 + + +def test_drawdown_halt(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + drawdown_reduction_threshold=0.1, + drawdown_halt_threshold=0.2, + ) + rm.update_balance(Decimal("10000")) + scale = rm.get_position_scale(Decimal("7500")) # 25% drawdown + assert scale == 0.0 + + +def test_consecutive_losses_pause(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + max_consecutive_losses=3, + loss_pause_minutes=60, + ) + rm.record_trade_result(False) + rm.record_trade_result(False) + assert not rm.is_paused() + rm.record_trade_result(False) # 3rd loss + assert rm.is_paused() + + +def test_consecutive_losses_reset_on_win(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + max_consecutive_losses=3, + ) + rm.record_trade_result(False) + rm.record_trade_result(False) + rm.record_trade_result(True) # Win resets counter + rm.record_trade_result(False) + assert not rm.is_paused() # Only 1 consecutive loss + + +def test_drawdown_check_rejects_in_check(): + rm = RiskManager( + max_position_size=Decimal("0.5"), + stop_loss_pct=Decimal("5"), + daily_loss_limit_pct=Decimal("10"), + drawdown_halt_threshold=0.15, + ) + rm.update_balance(Decimal("10000")) + signal = Signal( + strategy="test", + symbol="AAPL", + side=OrderSide.BUY, + price=Decimal("50000"), + quantity=Decimal("0.01"), + reason="test", + ) + result = rm.check(signal, Decimal("8000"), {}, Decimal("0")) # 20% dd > 15% + assert not result.allowed + assert "halted" in result.reason.lower() |
