diff options
Diffstat (limited to 'services/portfolio-manager/src/portfolio_manager/portfolio.py')
| -rw-r--r-- | services/portfolio-manager/src/portfolio_manager/portfolio.py | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/services/portfolio-manager/src/portfolio_manager/portfolio.py b/services/portfolio-manager/src/portfolio_manager/portfolio.py new file mode 100644 index 0000000..59106bb --- /dev/null +++ b/services/portfolio-manager/src/portfolio_manager/portfolio.py @@ -0,0 +1,62 @@ +"""Portfolio tracking for the portfolio manager service.""" +from decimal import Decimal + +from shared.models import Order, OrderSide, Position + + +class _PositionState: + """Internal state for tracking a single symbol's position.""" + + def __init__(self) -> None: + self.quantity: Decimal = Decimal("0") + self.avg_entry: Decimal = Decimal("0") + + +class PortfolioTracker: + """Tracks positions and updates them based on filled orders.""" + + def __init__(self) -> None: + self._positions: dict[str, _PositionState] = {} + + def _get_or_create(self, symbol: str) -> _PositionState: + if symbol not in self._positions: + self._positions[symbol] = _PositionState() + return self._positions[symbol] + + def apply_order(self, order: Order) -> None: + """Update internal position state based on a filled order.""" + state = self._get_or_create(order.symbol) + + if order.side == OrderSide.BUY: + # Weighted average entry price + total_cost = state.avg_entry * state.quantity + order.price * order.quantity + state.quantity += order.quantity + if state.quantity > Decimal("0"): + state.avg_entry = total_cost / state.quantity + elif order.side == OrderSide.SELL: + state.quantity -= order.quantity + # Keep avg_entry unchanged unless fully sold + if state.quantity <= Decimal("0"): + state.quantity = Decimal("0") + state.avg_entry = Decimal("0") + + def get_position(self, symbol: str) -> Position | None: + """Return a Position model for the symbol, or None if no/zero position.""" + state = self._positions.get(symbol) + if state is None or state.quantity <= Decimal("0"): + return None + return Position( + symbol=symbol, + quantity=state.quantity, + avg_entry_price=state.avg_entry, + current_price=state.avg_entry, # No live price here; caller can update + ) + + def get_all_positions(self) -> list[Position]: + """Return all non-zero positions.""" + positions = [] + for symbol in self._positions: + pos = self.get_position(symbol) + if pos is not None: + positions.append(pos) + return positions |
