"""Portfolio tracking for the portfolio manager service.""" from decimal import Decimal from shared.models import Order, OrderSide, Position class _PositionState: """Internal state for tracking a single symbol's position.""" def __init__(self) -> None: self.quantity: Decimal = Decimal("0") self.avg_entry: Decimal = Decimal("0") class PortfolioTracker: """Tracks positions and updates them based on filled orders.""" def __init__(self) -> None: self._positions: dict[str, _PositionState] = {} self._realized_pnl: Decimal = Decimal("0") @property def realized_pnl(self) -> Decimal: return self._realized_pnl def _get_or_create(self, symbol: str) -> _PositionState: if symbol not in self._positions: self._positions[symbol] = _PositionState() return self._positions[symbol] def apply_order(self, order: Order) -> None: """Update internal position state based on a filled order.""" state = self._get_or_create(order.symbol) if order.side == OrderSide.BUY: # Weighted average entry price total_cost = state.avg_entry * state.quantity + order.price * order.quantity state.quantity += order.quantity if state.quantity > Decimal("0"): state.avg_entry = total_cost / state.quantity elif order.side == OrderSide.SELL: # Calculate realized PnL for this sell sell_quantity = min(order.quantity, state.quantity) if sell_quantity > 0 and state.avg_entry > 0: self._realized_pnl += sell_quantity * (order.price - state.avg_entry) state.quantity -= sell_quantity if state.quantity <= Decimal("0"): state.quantity = Decimal("0") state.avg_entry = Decimal("0") def get_position(self, symbol: str) -> Position | None: """Return a Position model for the symbol, or None if no/zero position.""" state = self._positions.get(symbol) if state is None or state.quantity <= Decimal("0"): return None return Position( symbol=symbol, quantity=state.quantity, avg_entry_price=state.avg_entry, current_price=state.avg_entry, # No live price here; caller can update ) def get_all_positions(self) -> list[Position]: """Return all non-zero positions.""" positions = [] for symbol in self._positions: pos = self.get_position(symbol) if pos is not None: positions.append(pos) return positions