summaryrefslogtreecommitdiff
path: root/services/order-executor
diff options
context:
space:
mode:
Diffstat (limited to 'services/order-executor')
-rw-r--r--services/order-executor/pyproject.toml2
-rw-r--r--services/order-executor/src/order_executor/executor.py14
-rw-r--r--services/order-executor/src/order_executor/main.py80
-rw-r--r--services/order-executor/src/order_executor/risk_manager.py251
-rw-r--r--services/order-executor/tests/test_executor.py20
-rw-r--r--services/order-executor/tests/test_risk_manager.py198
6 files changed, 487 insertions, 78 deletions
diff --git a/services/order-executor/pyproject.toml b/services/order-executor/pyproject.toml
index eed4fef..7bb1030 100644
--- a/services/order-executor/pyproject.toml
+++ b/services/order-executor/pyproject.toml
@@ -3,7 +3,7 @@ name = "order-executor"
version = "0.1.0"
description = "Order execution service with risk management"
requires-python = ">=3.12"
-dependencies = ["ccxt>=4.0", "trading-shared"]
+dependencies = ["trading-shared"]
[project.optional-dependencies]
dev = ["pytest>=8.0", "pytest-asyncio>=0.23"]
diff --git a/services/order-executor/src/order_executor/executor.py b/services/order-executor/src/order_executor/executor.py
index 80f441d..a71e762 100644
--- a/services/order-executor/src/order_executor/executor.py
+++ b/services/order-executor/src/order_executor/executor.py
@@ -37,12 +37,8 @@ class OrderExecutor:
async def execute(self, signal: Signal) -> Optional[Order]:
"""Run risk checks and place an order for the given signal."""
- # Fetch current balance from exchange
- balance_data = await self.exchange.fetch_balance()
- # Use USDT (or quote currency) free balance as available capital
- free_balances = balance_data.get("free", {})
- quote_currency = signal.symbol.split("/")[-1] if "/" in signal.symbol else "USDT"
- balance = Decimal(str(free_balances.get(quote_currency, 0)))
+ # Fetch buying power from Alpaca
+ balance = await self.exchange.get_buying_power()
# Fetch current positions
positions = {}
@@ -84,11 +80,11 @@ class OrderExecutor:
)
else:
try:
- await self.exchange.create_order(
+ await self.exchange.submit_order(
symbol=signal.symbol,
- type="market",
+ qty=float(signal.quantity),
side=signal.side.value.lower(),
- amount=float(signal.quantity),
+ type="market",
)
order.status = OrderStatus.FILLED
order.filled_at = datetime.now(timezone.utc)
diff --git a/services/order-executor/src/order_executor/main.py b/services/order-executor/src/order_executor/main.py
index 3fe4c12..51ab286 100644
--- a/services/order-executor/src/order_executor/main.py
+++ b/services/order-executor/src/order_executor/main.py
@@ -3,11 +3,11 @@
import asyncio
from decimal import Decimal
+from shared.alpaca import AlpacaClient
from shared.broker import RedisBroker
from shared.db import Database
from shared.events import Event, EventType
from shared.healthcheck import HealthCheckServer
-from shared.exchange import create_exchange
from shared.logging import setup_logging
from shared.metrics import ServiceMetrics
from shared.notifier import TelegramNotifier
@@ -16,9 +16,7 @@ from order_executor.config import ExecutorConfig
from order_executor.executor import OrderExecutor
from order_executor.risk_manager import RiskManager
-# Health check port: base (HEALTH_PORT, default 8080) + offset
-# data-collector: +0 (8080), strategy-engine: +1 (8081)
-# order-executor: +2 (8082), portfolio-manager: +3 (8083)
+# Health check port: base + 2
HEALTH_PORT_OFFSET = 2
@@ -26,21 +24,21 @@ async def run() -> None:
config = ExecutorConfig()
log = setup_logging("order-executor", config.log_level, config.log_format)
metrics = ServiceMetrics("order_executor")
+
notifier = TelegramNotifier(
- bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id
+ bot_token=config.telegram_bot_token,
+ chat_id=config.telegram_chat_id,
)
db = Database(config.database_url)
await db.connect()
- await db.init_tables()
broker = RedisBroker(config.redis_url)
- exchange = create_exchange(
- exchange_id=config.exchange_id,
- api_key=config.binance_api_key,
- api_secret=config.binance_api_secret,
- sandbox=config.exchange_sandbox,
+ alpaca = AlpacaClient(
+ api_key=config.alpaca_api_key,
+ api_secret=config.alpaca_api_secret,
+ paper=config.alpaca_paper,
)
risk_manager = RiskManager(
@@ -51,10 +49,19 @@ async def run() -> None:
max_open_positions=config.risk_max_open_positions,
volatility_lookback=config.risk_volatility_lookback,
volatility_scale=config.risk_volatility_scale,
+ max_portfolio_exposure=config.risk_max_portfolio_exposure,
+ max_correlated_exposure=config.risk_max_correlated_exposure,
+ correlation_threshold=config.risk_correlation_threshold,
+ var_confidence=config.risk_var_confidence,
+ var_limit_pct=config.risk_var_limit_pct,
+ drawdown_reduction_threshold=config.risk_drawdown_reduction_threshold,
+ drawdown_halt_threshold=config.risk_drawdown_halt_threshold,
+ max_consecutive_losses=config.risk_max_consecutive_losses,
+ loss_pause_minutes=config.risk_loss_pause_minutes,
)
executor = OrderExecutor(
- exchange=exchange,
+ exchange=alpaca,
risk_manager=risk_manager,
broker=broker,
db=db,
@@ -62,41 +69,34 @@ async def run() -> None:
dry_run=config.dry_run,
)
- GROUP = "order-executor"
- CONSUMER = "executor-1"
- stream = "signals"
-
health = HealthCheckServer(
"order-executor",
port=config.health_port + HEALTH_PORT_OFFSET,
auth_token=config.metrics_auth_token,
)
- health.register_check("redis", broker.ping)
await health.start()
metrics.service_up.labels(service="order-executor").set(1)
- log.info("service_started", stream=stream, dry_run=config.dry_run)
+ GROUP = "order-executor"
+ CONSUMER = "executor-1"
+ stream = "signals"
await broker.ensure_group(stream, GROUP)
- # Process pending messages first (from previous crash)
- pending = await broker.read_pending(stream, GROUP, CONSUMER)
- for msg_id, msg in pending:
- try:
- event = Event.from_dict(msg)
- if event.type == EventType.SIGNAL:
- signal = event.data
- log.info(
- "processing_pending_signal", signal_id=str(signal.id), symbol=signal.symbol
- )
- await executor.execute(signal)
- metrics.events_processed.labels(service="order-executor", event_type="signal").inc()
- await broker.ack(stream, GROUP, msg_id)
- except Exception as exc:
- log.error("pending_process_failed", error=str(exc), msg_id=msg_id)
- metrics.errors_total.labels(service="order-executor", error_type="processing").inc()
+ log.info("started", stream=stream, dry_run=config.dry_run)
try:
+ # Process pending messages first
+ pending = await broker.read_pending(stream, GROUP, CONSUMER)
+ for msg_id, msg in pending:
+ try:
+ event = Event.from_dict(msg)
+ if event.type == EventType.SIGNAL:
+ await executor.execute(event.data)
+ await broker.ack(stream, GROUP, msg_id)
+ except Exception as exc:
+ log.error("pending_failed", error=str(exc), msg_id=msg_id)
+
while True:
messages = await broker.read_group(stream, GROUP, CONSUMER, count=10, block=5000)
for msg_id, msg in messages:
@@ -104,29 +104,23 @@ async def run() -> None:
event = Event.from_dict(msg)
if event.type == EventType.SIGNAL:
signal = event.data
- log.info(
- "processing_signal", signal_id=str(signal.id), symbol=signal.symbol
- )
+ log.info("processing_signal", signal_id=signal.id, symbol=signal.symbol)
await executor.execute(signal)
metrics.events_processed.labels(
service="order-executor", event_type="signal"
).inc()
await broker.ack(stream, GROUP, msg_id)
except Exception as exc:
- log.error("message_processing_failed", error=str(exc), msg_id=msg_id)
+ log.error("process_failed", error=str(exc))
metrics.errors_total.labels(
service="order-executor", error_type="processing"
).inc()
- except Exception as exc:
- log.error("fatal_error", error=str(exc))
- await notifier.send_error(str(exc), "order-executor")
- raise
finally:
metrics.service_up.labels(service="order-executor").set(0)
await notifier.close()
await broker.close()
await db.close()
- await exchange.close()
+ await alpaca.close()
def main() -> None:
diff --git a/services/order-executor/src/order_executor/risk_manager.py b/services/order-executor/src/order_executor/risk_manager.py
index c3578a7..5a05746 100644
--- a/services/order-executor/src/order_executor/risk_manager.py
+++ b/services/order-executor/src/order_executor/risk_manager.py
@@ -1,6 +1,7 @@
"""Risk management for order execution."""
from dataclasses import dataclass
+from datetime import datetime, timezone, timedelta
from decimal import Decimal
from collections import deque
import math
@@ -46,6 +47,15 @@ class RiskManager:
max_open_positions: int = 10,
volatility_lookback: int = 20,
volatility_scale: bool = False,
+ max_portfolio_exposure: float = 0.8,
+ max_correlated_exposure: float = 0.5,
+ correlation_threshold: float = 0.7,
+ var_confidence: float = 0.95,
+ var_limit_pct: float = 5.0,
+ drawdown_reduction_threshold: float = 0.1, # Start reducing at 10% drawdown
+ drawdown_halt_threshold: float = 0.2, # Halt trading at 20% drawdown
+ max_consecutive_losses: int = 5, # Pause after 5 consecutive losses
+ loss_pause_minutes: int = 60, # Pause for 60 minutes after consecutive losses
) -> None:
self.max_position_size = max_position_size
self.stop_loss_pct = stop_loss_pct
@@ -57,6 +67,75 @@ class RiskManager:
self._trailing_stops: dict[str, TrailingStop] = {}
self._price_history: dict[str, deque[float]] = {}
+ self._return_history: dict[str, list[float]] = {}
+ self._max_portfolio_exposure = Decimal(str(max_portfolio_exposure))
+ self._max_correlated_exposure = Decimal(str(max_correlated_exposure))
+ self._correlation_threshold = correlation_threshold
+ self._var_confidence = var_confidence
+ self._var_limit_pct = Decimal(str(var_limit_pct))
+
+ self._drawdown_reduction_threshold = drawdown_reduction_threshold
+ self._drawdown_halt_threshold = drawdown_halt_threshold
+ self._max_consecutive_losses = max_consecutive_losses
+ self._loss_pause_minutes = loss_pause_minutes
+
+ self._peak_balance: Decimal = Decimal("0")
+ self._consecutive_losses: int = 0
+ self._paused_until: datetime | None = None
+
+ def update_balance(self, current_balance: Decimal) -> None:
+ """Track peak balance for drawdown calculation."""
+ if current_balance > self._peak_balance:
+ self._peak_balance = current_balance
+
+ def get_current_drawdown(self, current_balance: Decimal) -> float:
+ """Calculate current drawdown from peak as a fraction (0.0 to 1.0)."""
+ if self._peak_balance <= 0:
+ return 0.0
+ dd = float((self._peak_balance - current_balance) / self._peak_balance)
+ return max(dd, 0.0)
+
+ def get_position_scale(self, current_balance: Decimal) -> float:
+ """Get position size multiplier based on current drawdown.
+
+ Returns 1.0 (full size) when no drawdown.
+ Linearly reduces to 0.25 between reduction threshold and halt threshold.
+ Returns 0.0 at or beyond halt threshold.
+ """
+ dd = self.get_current_drawdown(current_balance)
+
+ if dd >= self._drawdown_halt_threshold:
+ return 0.0
+
+ if dd >= self._drawdown_reduction_threshold:
+ # Linear interpolation from 1.0 to 0.25
+ range_pct = (dd - self._drawdown_reduction_threshold) / (
+ self._drawdown_halt_threshold - self._drawdown_reduction_threshold
+ )
+ return max(1.0 - 0.75 * range_pct, 0.25)
+
+ return 1.0
+
+ def record_trade_result(self, is_win: bool) -> None:
+ """Record a trade result for consecutive loss tracking."""
+ if is_win:
+ self._consecutive_losses = 0
+ else:
+ self._consecutive_losses += 1
+ if self._consecutive_losses >= self._max_consecutive_losses:
+ self._paused_until = datetime.now(timezone.utc) + timedelta(
+ minutes=self._loss_pause_minutes
+ )
+
+ def is_paused(self) -> bool:
+ """Check if trading is paused due to consecutive losses."""
+ if self._paused_until is None:
+ return False
+ if datetime.now(timezone.utc) >= self._paused_until:
+ self._paused_until = None
+ self._consecutive_losses = 0
+ return False
+ return True
def update_price(self, symbol: str, price: Decimal) -> None:
"""Update price tracking for trailing stops and volatility."""
@@ -120,6 +199,145 @@ class RiskManager:
scale = min(target_vol / vol, 2.0) # Cap at 2x
return base_size * Decimal(str(scale))
+ def calculate_correlation(self, symbol_a: str, symbol_b: str) -> float | None:
+ """Calculate Pearson correlation between two symbols' returns."""
+ hist_a = self._price_history.get(symbol_a)
+ hist_b = self._price_history.get(symbol_b)
+ if not hist_a or not hist_b or len(hist_a) < 5 or len(hist_b) < 5:
+ return None
+
+ prices_a = list(hist_a)
+ prices_b = list(hist_b)
+ min_len = min(len(prices_a), len(prices_b))
+ prices_a = prices_a[-min_len:]
+ prices_b = prices_b[-min_len:]
+
+ returns_a = [
+ (prices_a[i] - prices_a[i - 1]) / prices_a[i - 1]
+ for i in range(1, len(prices_a))
+ if prices_a[i - 1] != 0
+ ]
+ returns_b = [
+ (prices_b[i] - prices_b[i - 1]) / prices_b[i - 1]
+ for i in range(1, len(prices_b))
+ if prices_b[i - 1] != 0
+ ]
+
+ if len(returns_a) < 3 or len(returns_b) < 3:
+ return None
+
+ min_len = min(len(returns_a), len(returns_b))
+ returns_a = returns_a[-min_len:]
+ returns_b = returns_b[-min_len:]
+
+ mean_a = sum(returns_a) / len(returns_a)
+ mean_b = sum(returns_b) / len(returns_b)
+
+ cov = sum((a - mean_a) * (b - mean_b) for a, b in zip(returns_a, returns_b)) / len(
+ returns_a
+ )
+ std_a = math.sqrt(sum((a - mean_a) ** 2 for a in returns_a) / len(returns_a))
+ std_b = math.sqrt(sum((b - mean_b) ** 2 for b in returns_b) / len(returns_b))
+
+ if std_a == 0 or std_b == 0:
+ return None
+
+ return cov / (std_a * std_b)
+
+ def calculate_portfolio_var(self, positions: dict[str, Position], balance: Decimal) -> float:
+ """Calculate portfolio VaR using historical simulation.
+
+ Returns VaR as a percentage of balance (e.g., 3.5 for 3.5%).
+ """
+ if not positions or balance <= 0:
+ return 0.0
+
+ # Collect returns for all positioned symbols
+ all_returns: list[list[float]] = []
+ weights: list[float] = []
+
+ for symbol, pos in positions.items():
+ if pos.quantity <= 0:
+ continue
+ hist = self._price_history.get(symbol)
+ if not hist or len(hist) < 5:
+ continue
+ prices = list(hist)
+ returns = [
+ (prices[i] - prices[i - 1]) / prices[i - 1]
+ for i in range(1, len(prices))
+ if prices[i - 1] != 0
+ ]
+ if returns:
+ all_returns.append(returns)
+ weight = float(pos.quantity * pos.current_price / balance)
+ weights.append(weight)
+
+ if not all_returns:
+ return 0.0
+
+ # Portfolio returns (weighted sum)
+ min_len = min(len(r) for r in all_returns)
+ portfolio_returns = []
+ for i in range(min_len):
+ pr = sum(w * r[-(min_len - i)] for w, r in zip(weights, all_returns) if len(r) > i)
+ portfolio_returns.append(pr)
+
+ if not portfolio_returns:
+ return 0.0
+
+ # Historical VaR: sort returns, take the (1-confidence) percentile
+ sorted_returns = sorted(portfolio_returns)
+ index = int((1 - self._var_confidence) * len(sorted_returns))
+ index = max(0, min(index, len(sorted_returns) - 1))
+ var_return = sorted_returns[index]
+
+ return abs(var_return) * 100 # As percentage
+
+ def check_portfolio_exposure(
+ self, positions: dict[str, Position], balance: Decimal
+ ) -> RiskCheckResult:
+ """Check total portfolio exposure."""
+ if balance <= 0:
+ return RiskCheckResult(allowed=True, reason="OK")
+
+ total_exposure = sum(
+ pos.quantity * pos.current_price for pos in positions.values() if pos.quantity > 0
+ )
+
+ exposure_ratio = total_exposure / balance
+ if exposure_ratio > self._max_portfolio_exposure:
+ return RiskCheckResult(
+ allowed=False,
+ reason=f"Portfolio exposure {float(exposure_ratio):.1%} exceeds max {float(self._max_portfolio_exposure):.1%}",
+ )
+
+ return RiskCheckResult(allowed=True, reason="OK")
+
+ def check_correlation_risk(
+ self, signal: Signal, positions: dict[str, Position], balance: Decimal
+ ) -> RiskCheckResult:
+ """Check if adding this position creates too much correlated exposure."""
+ if signal.side != OrderSide.BUY or balance <= 0:
+ return RiskCheckResult(allowed=True, reason="OK")
+
+ correlated_value = signal.price * signal.quantity
+
+ for symbol, pos in positions.items():
+ if pos.quantity <= 0 or symbol == signal.symbol:
+ continue
+ corr = self.calculate_correlation(signal.symbol, symbol)
+ if corr is not None and abs(corr) >= self._correlation_threshold:
+ correlated_value += pos.quantity * pos.current_price
+
+ if correlated_value / balance > self._max_correlated_exposure:
+ return RiskCheckResult(
+ allowed=False,
+ reason=f"Correlated exposure would exceed {float(self._max_correlated_exposure):.1%}",
+ )
+
+ return RiskCheckResult(allowed=True, reason="OK")
+
def check(
self,
signal: Signal,
@@ -128,6 +346,21 @@ class RiskManager:
daily_pnl: Decimal,
) -> RiskCheckResult:
"""Run risk checks against a signal and current portfolio state."""
+ # Check if paused due to consecutive losses
+ if self.is_paused():
+ return RiskCheckResult(
+ allowed=False,
+ reason=f"Trading paused until {self._paused_until.isoformat()} after {self._max_consecutive_losses} consecutive losses",
+ )
+
+ # Check drawdown halt
+ dd = self.get_current_drawdown(balance)
+ if dd >= self._drawdown_halt_threshold:
+ return RiskCheckResult(
+ allowed=False,
+ reason=f"Trading halted: drawdown {dd:.1%} exceeds {self._drawdown_halt_threshold:.1%}",
+ )
+
# Check daily loss limit
if balance > 0 and (daily_pnl / balance) * 100 < -self.daily_loss_limit_pct:
return RiskCheckResult(allowed=False, reason="Daily loss limit exceeded")
@@ -165,4 +398,22 @@ class RiskManager:
):
return RiskCheckResult(allowed=False, reason="Position size exceeded")
+ # Portfolio-level checks
+ exposure_check = self.check_portfolio_exposure(positions, balance)
+ if not exposure_check.allowed:
+ return exposure_check
+
+ corr_check = self.check_correlation_risk(signal, positions, balance)
+ if not corr_check.allowed:
+ return corr_check
+
+ # VaR check
+ if positions:
+ var = self.calculate_portfolio_var(positions, balance)
+ if var > float(self._var_limit_pct):
+ return RiskCheckResult(
+ allowed=False,
+ reason=f"Portfolio VaR {var:.1f}% exceeds limit {float(self._var_limit_pct):.1f}%",
+ )
+
return RiskCheckResult(allowed=True, reason="OK")
diff --git a/services/order-executor/tests/test_executor.py b/services/order-executor/tests/test_executor.py
index e64b6c0..dd823d7 100644
--- a/services/order-executor/tests/test_executor.py
+++ b/services/order-executor/tests/test_executor.py
@@ -13,7 +13,7 @@ from order_executor.risk_manager import RiskCheckResult, RiskManager
def make_signal(side: OrderSide = OrderSide.BUY, price: str = "100", quantity: str = "1") -> Signal:
return Signal(
strategy="test",
- symbol="BTC/USDT",
+ symbol="AAPL",
side=side,
price=Decimal(price),
quantity=Decimal(quantity),
@@ -21,10 +21,10 @@ def make_signal(side: OrderSide = OrderSide.BUY, price: str = "100", quantity: s
)
-def make_mock_exchange(free_usdt: float = 10000.0) -> AsyncMock:
+def make_mock_exchange(buying_power: str = "10000") -> AsyncMock:
exchange = AsyncMock()
- exchange.fetch_balance.return_value = {"free": {"USDT": free_usdt}}
- exchange.create_order = AsyncMock(return_value={"id": "exchange-order-123"})
+ exchange.get_buying_power = AsyncMock(return_value=Decimal(buying_power))
+ exchange.submit_order = AsyncMock(return_value={"id": "alpaca-order-123"})
return exchange
@@ -48,7 +48,7 @@ def make_mock_db() -> AsyncMock:
@pytest.mark.asyncio
async def test_executor_places_order_when_risk_passes():
- """When risk check passes, create_order is called and order status is FILLED."""
+ """When risk check passes, submit_order is called and order status is FILLED."""
exchange = make_mock_exchange()
risk_manager = make_mock_risk_manager(allowed=True)
broker = make_mock_broker()
@@ -68,14 +68,14 @@ async def test_executor_places_order_when_risk_passes():
assert order is not None
assert order.status == OrderStatus.FILLED
- exchange.create_order.assert_called_once()
+ exchange.submit_order.assert_called_once()
db.insert_order.assert_called_once_with(order)
broker.publish.assert_called_once()
@pytest.mark.asyncio
async def test_executor_rejects_when_risk_fails():
- """When risk check fails, create_order is not called and None is returned."""
+ """When risk check fails, submit_order is not called and None is returned."""
exchange = make_mock_exchange()
risk_manager = make_mock_risk_manager(allowed=False, reason="Position size exceeded")
broker = make_mock_broker()
@@ -94,14 +94,14 @@ async def test_executor_rejects_when_risk_fails():
order = await executor.execute(signal)
assert order is None
- exchange.create_order.assert_not_called()
+ exchange.submit_order.assert_not_called()
db.insert_order.assert_not_called()
broker.publish.assert_not_called()
@pytest.mark.asyncio
async def test_executor_dry_run_does_not_call_exchange():
- """In dry-run mode, risk passes, order is FILLED, but exchange.create_order is NOT called."""
+ """In dry-run mode, risk passes, order is FILLED, but exchange.submit_order is NOT called."""
exchange = make_mock_exchange()
risk_manager = make_mock_risk_manager(allowed=True)
broker = make_mock_broker()
@@ -121,6 +121,6 @@ async def test_executor_dry_run_does_not_call_exchange():
assert order is not None
assert order.status == OrderStatus.FILLED
- exchange.create_order.assert_not_called()
+ exchange.submit_order.assert_not_called()
db.insert_order.assert_called_once_with(order)
broker.publish.assert_called_once()
diff --git a/services/order-executor/tests/test_risk_manager.py b/services/order-executor/tests/test_risk_manager.py
index efabe73..3d5175b 100644
--- a/services/order-executor/tests/test_risk_manager.py
+++ b/services/order-executor/tests/test_risk_manager.py
@@ -7,7 +7,7 @@ from shared.models import OrderSide, Position, Signal
from order_executor.risk_manager import RiskManager
-def make_signal(side: OrderSide, price: str, quantity: str, symbol: str = "BTC/USDT") -> Signal:
+def make_signal(side: OrderSide, price: str, quantity: str, symbol: str = "AAPL") -> Signal:
return Signal(
strategy="test",
symbol=symbol,
@@ -93,7 +93,7 @@ def test_risk_check_rejects_insufficient_balance():
def test_trailing_stop_set_and_trigger():
"""Trailing stop should trigger when price drops below stop level."""
rm = make_risk_manager(trailing_stop_pct="5")
- rm.set_trailing_stop("BTC/USDT", Decimal("100"))
+ rm.set_trailing_stop("AAPL", Decimal("100"))
signal = make_signal(side=OrderSide.BUY, price="94", quantity="0.01")
result = rm.check(signal, balance=Decimal("10000"), positions={}, daily_pnl=Decimal("0"))
@@ -104,10 +104,10 @@ def test_trailing_stop_set_and_trigger():
def test_trailing_stop_updates_highest_price():
"""Trailing stop should track the highest price seen."""
rm = make_risk_manager(trailing_stop_pct="5")
- rm.set_trailing_stop("BTC/USDT", Decimal("100"))
+ rm.set_trailing_stop("AAPL", Decimal("100"))
# Price rises to 120 => stop at 114
- rm.update_price("BTC/USDT", Decimal("120"))
+ rm.update_price("AAPL", Decimal("120"))
# Price at 115 is above stop (114), should be allowed
signal = make_signal(side=OrderSide.BUY, price="115", quantity="0.01")
@@ -124,7 +124,7 @@ def test_trailing_stop_updates_highest_price():
def test_trailing_stop_not_triggered_above_stop():
"""Trailing stop should not trigger when price is above stop level."""
rm = make_risk_manager(trailing_stop_pct="5")
- rm.set_trailing_stop("BTC/USDT", Decimal("100"))
+ rm.set_trailing_stop("AAPL", Decimal("100"))
# Price at 96 is above stop (95), should be allowed
signal = make_signal(side=OrderSide.BUY, price="96", quantity="0.01")
@@ -140,11 +140,11 @@ def test_max_open_positions_check():
rm = make_risk_manager(max_open_positions=2)
positions = {
- "BTC/USDT": make_position("BTC/USDT", "1", "100", "100"),
- "ETH/USDT": make_position("ETH/USDT", "10", "50", "50"),
+ "AAPL": make_position("AAPL", "1", "100", "100"),
+ "MSFT": make_position("MSFT", "10", "50", "50"),
}
- signal = make_signal(side=OrderSide.BUY, price="10", quantity="1", symbol="SOL/USDT")
+ signal = make_signal(side=OrderSide.BUY, price="10", quantity="1", symbol="TSLA")
result = rm.check(signal, balance=Decimal("10000"), positions=positions, daily_pnl=Decimal("0"))
assert result.allowed is False
assert result.reason == "Max open positions reached"
@@ -158,14 +158,14 @@ def test_volatility_calculation():
rm = make_risk_manager(volatility_lookback=5)
# No history yet
- assert rm.get_volatility("BTC/USDT") is None
+ assert rm.get_volatility("AAPL") is None
# Feed prices
prices = [100, 102, 98, 105, 101]
for p in prices:
- rm.update_price("BTC/USDT", Decimal(str(p)))
+ rm.update_price("AAPL", Decimal(str(p)))
- vol = rm.get_volatility("BTC/USDT")
+ vol = rm.get_volatility("AAPL")
assert vol is not None
assert vol > 0
@@ -177,9 +177,9 @@ def test_position_size_with_volatility_scaling():
# Feed volatile prices
prices = [100, 120, 80, 130, 70]
for p in prices:
- rm.update_price("BTC/USDT", Decimal(str(p)))
+ rm.update_price("AAPL", Decimal(str(p)))
- size = rm.calculate_position_size("BTC/USDT", Decimal("10000"))
+ size = rm.calculate_position_size("AAPL", Decimal("10000"))
base = Decimal("10000") * Decimal("0.1")
# High volatility should reduce size below base
@@ -192,9 +192,177 @@ def test_position_size_without_scaling():
prices = [100, 120, 80, 130, 70]
for p in prices:
- rm.update_price("BTC/USDT", Decimal(str(p)))
+ rm.update_price("AAPL", Decimal(str(p)))
- size = rm.calculate_position_size("BTC/USDT", Decimal("10000"))
+ size = rm.calculate_position_size("AAPL", Decimal("10000"))
base = Decimal("10000") * Decimal("0.1")
assert size == base
+
+
+# --- Portfolio exposure tests ---
+
+
+def test_portfolio_exposure_check_passes():
+ rm = RiskManager(
+ max_position_size=Decimal("0.5"),
+ stop_loss_pct=Decimal("5"),
+ daily_loss_limit_pct=Decimal("10"),
+ max_portfolio_exposure=0.8,
+ )
+ positions = {
+ "AAPL": Position(
+ symbol="AAPL",
+ quantity=Decimal("0.01"),
+ avg_entry_price=Decimal("50000"),
+ current_price=Decimal("50000"),
+ )
+ }
+ result = rm.check_portfolio_exposure(positions, Decimal("10000"))
+ assert result.allowed # 500/10000 = 5% < 80%
+
+
+def test_portfolio_exposure_check_rejects():
+ rm = RiskManager(
+ max_position_size=Decimal("0.5"),
+ stop_loss_pct=Decimal("5"),
+ daily_loss_limit_pct=Decimal("10"),
+ max_portfolio_exposure=0.3,
+ )
+ positions = {
+ "AAPL": Position(
+ symbol="AAPL",
+ quantity=Decimal("1"),
+ avg_entry_price=Decimal("50000"),
+ current_price=Decimal("50000"),
+ )
+ }
+ result = rm.check_portfolio_exposure(positions, Decimal("10000"))
+ assert not result.allowed # 50000/10000 = 500% > 30%
+
+
+def test_correlation_calculation():
+ rm = RiskManager(
+ max_position_size=Decimal("0.5"),
+ stop_loss_pct=Decimal("5"),
+ daily_loss_limit_pct=Decimal("10"),
+ )
+ # Feed identical price histories — correlation should be ~1.0
+ for i in range(20):
+ rm.update_price("A", Decimal(str(100 + i)))
+ rm.update_price("B", Decimal(str(100 + i)))
+ corr = rm.calculate_correlation("A", "B")
+ assert corr is not None
+ assert corr > 0.9
+
+
+def test_var_calculation():
+ rm = RiskManager(
+ max_position_size=Decimal("0.5"),
+ stop_loss_pct=Decimal("5"),
+ daily_loss_limit_pct=Decimal("10"),
+ )
+ for i in range(30):
+ rm.update_price("AAPL", Decimal(str(100 + (i % 5) - 2)))
+ positions = {
+ "AAPL": Position(
+ symbol="AAPL",
+ quantity=Decimal("1"),
+ avg_entry_price=Decimal("100"),
+ current_price=Decimal("100"),
+ )
+ }
+ var = rm.calculate_portfolio_var(positions, Decimal("10000"))
+ assert var >= 0 # Non-negative
+
+
+# --- Drawdown position scaling tests ---
+
+
+def test_drawdown_position_scale_full():
+ rm = RiskManager(
+ max_position_size=Decimal("0.5"),
+ stop_loss_pct=Decimal("5"),
+ daily_loss_limit_pct=Decimal("10"),
+ drawdown_reduction_threshold=0.1,
+ drawdown_halt_threshold=0.2,
+ )
+ rm.update_balance(Decimal("10000"))
+ scale = rm.get_position_scale(Decimal("10000"))
+ assert scale == 1.0 # No drawdown
+
+
+def test_drawdown_position_scale_reduced():
+ rm = RiskManager(
+ max_position_size=Decimal("0.5"),
+ stop_loss_pct=Decimal("5"),
+ daily_loss_limit_pct=Decimal("10"),
+ drawdown_reduction_threshold=0.1,
+ drawdown_halt_threshold=0.2,
+ )
+ rm.update_balance(Decimal("10000"))
+ scale = rm.get_position_scale(Decimal("8500")) # 15% drawdown (between 10% and 20%)
+ assert 0.25 < scale < 1.0
+
+
+def test_drawdown_halt():
+ rm = RiskManager(
+ max_position_size=Decimal("0.5"),
+ stop_loss_pct=Decimal("5"),
+ daily_loss_limit_pct=Decimal("10"),
+ drawdown_reduction_threshold=0.1,
+ drawdown_halt_threshold=0.2,
+ )
+ rm.update_balance(Decimal("10000"))
+ scale = rm.get_position_scale(Decimal("7500")) # 25% drawdown
+ assert scale == 0.0
+
+
+def test_consecutive_losses_pause():
+ rm = RiskManager(
+ max_position_size=Decimal("0.5"),
+ stop_loss_pct=Decimal("5"),
+ daily_loss_limit_pct=Decimal("10"),
+ max_consecutive_losses=3,
+ loss_pause_minutes=60,
+ )
+ rm.record_trade_result(False)
+ rm.record_trade_result(False)
+ assert not rm.is_paused()
+ rm.record_trade_result(False) # 3rd loss
+ assert rm.is_paused()
+
+
+def test_consecutive_losses_reset_on_win():
+ rm = RiskManager(
+ max_position_size=Decimal("0.5"),
+ stop_loss_pct=Decimal("5"),
+ daily_loss_limit_pct=Decimal("10"),
+ max_consecutive_losses=3,
+ )
+ rm.record_trade_result(False)
+ rm.record_trade_result(False)
+ rm.record_trade_result(True) # Win resets counter
+ rm.record_trade_result(False)
+ assert not rm.is_paused() # Only 1 consecutive loss
+
+
+def test_drawdown_check_rejects_in_check():
+ rm = RiskManager(
+ max_position_size=Decimal("0.5"),
+ stop_loss_pct=Decimal("5"),
+ daily_loss_limit_pct=Decimal("10"),
+ drawdown_halt_threshold=0.15,
+ )
+ rm.update_balance(Decimal("10000"))
+ signal = Signal(
+ strategy="test",
+ symbol="AAPL",
+ side=OrderSide.BUY,
+ price=Decimal("50000"),
+ quantity=Decimal("0.01"),
+ reason="test",
+ )
+ result = rm.check(signal, Decimal("8000"), {}, Decimal("0")) # 20% dd > 15%
+ assert not result.allowed
+ assert "halted" in result.reason.lower()