blob: 4a7a11f4cd5d7f1e5fc2cdc87e5d0cdf8bd28e82 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
|
"""Portfolio tracking for the portfolio manager service."""
from decimal import Decimal
from shared.models import Order, OrderSide, Position
class _PositionState:
    """Mutable per-symbol record: held quantity and its average entry price.

    A freshly created state has both fields at zero.
    """

    def __init__(self) -> None:
        zero = Decimal(0)
        # Quantity currently held for the symbol.
        self.quantity: Decimal = zero
        # Average entry price of the held quantity.
        self.avg_entry: Decimal = zero
class PortfolioTracker:
    """Maintain per-symbol positions and cumulative realized PnL from orders.

    Positions are keyed by symbol. Orders passed to ``apply_order`` are
    treated as filled; no fill status is checked here.
    """

    def __init__(self) -> None:
        # symbol -> mutable position state; entries are never removed, a
        # closed position simply has its quantity reset to zero.
        self._positions: dict[str, _PositionState] = {}
        # Running total of PnL booked by sell orders.
        self._realized_pnl: Decimal = Decimal("0")

    @property
    def realized_pnl(self) -> Decimal:
        """Cumulative realized PnL booked so far across all symbols."""
        return self._realized_pnl

    def _get_or_create(self, symbol: str) -> _PositionState:
        """Return the state for ``symbol``, registering an empty one if absent."""
        try:
            return self._positions[symbol]
        except KeyError:
            fresh = self._positions[symbol] = _PositionState()
            return fresh

    def _apply_buy(self, state: _PositionState, order: Order) -> None:
        """Add the bought quantity and recompute the weighted average entry."""
        held_cost = state.avg_entry * state.quantity
        added_cost = order.price * order.quantity
        combined_cost = held_cost + added_cost
        state.quantity += order.quantity
        # Only recompute when there is something to divide by; otherwise the
        # previous average is left untouched.
        if state.quantity > Decimal("0"):
            state.avg_entry = combined_cost / state.quantity

    def _apply_sell(self, state: _PositionState, order: Order) -> None:
        """Reduce the position and book realized PnL for the closed quantity."""
        # A sell never takes the position below zero: anything beyond the
        # held quantity is ignored.
        closed = min(order.quantity, state.quantity)
        # PnL is booked only when something was closed and the entry price
        # is positive.
        if closed > 0 and state.avg_entry > 0:
            self._realized_pnl += closed * (order.price - state.avg_entry)
        state.quantity -= closed
        if state.quantity <= Decimal("0"):
            # Position fully closed: reset both fields to a clean zero state.
            state.quantity = Decimal("0")
            state.avg_entry = Decimal("0")

    def apply_order(self, order: Order) -> None:
        """Fold an order into the tracked state for its symbol.

        BUY orders increase the quantity and update the weighted average
        entry price. SELL orders reduce the quantity (clamped to what is
        held) and add to realized PnL. Orders with any other side leave the
        position unchanged, although a state entry for the symbol is still
        registered.
        """
        state = self._get_or_create(order.symbol)
        if order.side == OrderSide.BUY:
            self._apply_buy(state, order)
            return
        if order.side == OrderSide.SELL:
            self._apply_sell(state, order)

    def get_position(self, symbol: str) -> Position | None:
        """Build a ``Position`` for ``symbol``; ``None`` when nothing is held."""
        tracked = self._positions.get(symbol)
        if tracked is None:
            return None
        if tracked.quantity <= Decimal("0"):
            return None
        held, entry = tracked.quantity, tracked.avg_entry
        return Position(
            symbol=symbol,
            quantity=held,
            avg_entry_price=entry,
            # No live price is available here, so the entry price is used as
            # a placeholder; the caller can update it.
            current_price=entry,
        )

    def get_all_positions(self) -> list[Position]:
        """Return a ``Position`` for every symbol with a non-zero quantity."""
        candidates = map(self.get_position, self._positions)
        return [position for position in candidates if position is not None]
|