diff options
Diffstat (limited to 'shared/src')
| -rw-r--r-- | shared/src/shared/db.py | 3 | ||||
| -rw-r--r-- | shared/src/shared/events.py | 16 | ||||
| -rw-r--r-- | shared/src/shared/sa_models.py | 16 |
3 files changed, 33 insertions, 2 deletions
diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py index a718951..8fee000 100644 --- a/shared/src/shared/db.py +++ b/shared/src/shared/db.py @@ -112,6 +112,9 @@ class Database: price=signal.price, quantity=signal.quantity, reason=signal.reason, + conviction=signal.conviction, + stop_loss=signal.stop_loss, + take_profit=signal.take_profit, created_at=signal.created_at, ) async with self._session_factory() as session: diff --git a/shared/src/shared/events.py b/shared/src/shared/events.py index 6f8def1..61b85bd 100644 --- a/shared/src/shared/events.py +++ b/shared/src/shared/events.py @@ -88,6 +88,18 @@ class Event: @staticmethod def from_dict(data: dict) -> Any: - event_type = EventType(data["type"]) + """Deserialize a raw dict into the appropriate event type. + + Raises ValueError for malformed or unrecognized event data. + """ + try: + event_type = EventType(data["type"]) + except (KeyError, ValueError) as exc: + raise ValueError(f"Invalid or missing event type in data: {data!r}") from exc cls = _EVENT_TYPE_MAP[event_type] - return cls.from_raw(data) + try: + return cls.from_raw(data) + except KeyError as exc: + raise ValueError( + f"Missing required field in {event_type} event data: {exc}" + ) from exc diff --git a/shared/src/shared/sa_models.py b/shared/src/shared/sa_models.py index dc87ef5..b70a6c4 100644 --- a/shared/src/shared/sa_models.py +++ b/shared/src/shared/sa_models.py @@ -35,6 +35,9 @@ class SignalRow(Base): price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) reason: Mapped[str | None] = mapped_column(Text) + conviction: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default="1.0") + stop_loss: Mapped[Decimal | None] = mapped_column(Numeric) + take_profit: Mapped[Decimal | None] = mapped_column(Numeric) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) @@ -53,6 +56,19 @@ class OrderRow(Base): filled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) +class TradeRow(Base): + __tablename__ = "trades" + + id: Mapped[str] = mapped_column(Text, primary_key=True) + order_id: Mapped[str | None] = mapped_column(Text, ForeignKey("orders.id")) + symbol: Mapped[str] = mapped_column(Text, nullable=False) + side: Mapped[str] = mapped_column(Text, nullable=False) + price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) + fee: Mapped[Decimal] = mapped_column(Numeric, nullable=False, server_default="0") + traded_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + + class PositionRow(Base): __tablename__ = "positions" |
