summary refs log tree commit diff
path: root/shared
diff options
context:
space:
mode:
Diffstat (limited to 'shared')
-rw-r--r--shared/pyproject.toml25
-rw-r--r--shared/src/shared/__init__.py1
-rw-r--r--shared/src/shared/broker.py43
-rw-r--r--shared/src/shared/config.py16
-rw-r--r--shared/src/shared/db.py184
-rw-r--r--shared/src/shared/events.py75
-rw-r--r--shared/src/shared/models.py72
-rw-r--r--shared/tests/__init__.py0
-rw-r--r--shared/tests/test_broker.py66
-rw-r--r--shared/tests/test_db.py70
-rw-r--r--shared/tests/test_events.py80
-rw-r--r--shared/tests/test_models.py100
12 files changed, 732 insertions, 0 deletions
diff --git a/shared/pyproject.toml b/shared/pyproject.toml
new file mode 100644
index 0000000..bd09d3e
--- /dev/null
+++ b/shared/pyproject.toml
@@ -0,0 +1,25 @@
[project]
name = "trading-shared"
version = "0.1.0"
description = "Shared models, events, and utilities for trading platform"
requires-python = ">=3.12"
# Runtime dependencies needed by every service that imports this package.
dependencies = [
    "pydantic>=2.0",
    "pydantic-settings>=2.0",
    "redis>=5.0",
    "asyncpg>=0.29",
]

# Development extras: install with `pip install trading-shared[dev]`.
[project.optional-dependencies]
dev = [
    "pytest>=8.0",
    "pytest-asyncio>=0.23",
    "ruff>=0.4",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

# src-layout: the wheel ships the package found under src/shared.
[tool.hatch.build.targets.wheel]
packages = ["src/shared"]
diff --git a/shared/src/shared/__init__.py b/shared/src/shared/__init__.py
new file mode 100644
index 0000000..d2ee024
--- /dev/null
+++ b/shared/src/shared/__init__.py
@@ -0,0 +1 @@
+"""Shared library for the trading platform."""
diff --git a/shared/src/shared/broker.py b/shared/src/shared/broker.py
new file mode 100644
index 0000000..9a50441
--- /dev/null
+++ b/shared/src/shared/broker.py
@@ -0,0 +1,43 @@
+"""Redis Streams broker for the trading platform."""
+import json
+from typing import Any
+
+import redis.asyncio
+
+
class RedisBroker:
    """Async Redis Streams broker for publishing and reading events.

    Messages are JSON-encoded and stored under a single ``payload`` field
    of each stream entry.
    """

    def __init__(self, redis_url: str) -> None:
        # Client connects lazily; no I/O happens until the first command.
        self._redis = redis.asyncio.from_url(redis_url)

    async def publish(self, stream: str, data: dict[str, Any]) -> None:
        """Publish a message to a Redis stream."""
        await self._redis.xadd(stream, {"payload": json.dumps(data)})

    async def read(
        self,
        stream: str,
        last_id: str = "$",
        count: int = 10,
        block: int = 0,
    ) -> list[dict[str, Any]]:
        """Read up to *count* messages from *stream* after *last_id*.

        Returns the decoded JSON payloads; entries without a payload
        field are skipped.
        """
        response = await self._redis.xread(
            {stream: last_id}, count=count, block=block
        )
        decoded: list[dict[str, Any]] = []
        for _stream_name, entries in response or []:
            for _entry_id, fields in entries:
                # Field keys may be bytes or str depending on client config.
                raw = fields.get(b"payload") or fields.get("payload")
                if not raw:
                    continue
                text = raw.decode() if isinstance(raw, bytes) else raw
                decoded.append(json.loads(text))
        return decoded

    async def close(self) -> None:
        """Close the Redis connection."""
        await self._redis.aclose()
diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py
new file mode 100644
index 0000000..1304c5e
--- /dev/null
+++ b/shared/src/shared/config.py
@@ -0,0 +1,16 @@
+"""Shared configuration settings for the trading platform."""
+from pydantic_settings import BaseSettings
+
+
class Settings(BaseSettings):
    """Application settings loaded from environment variables and `.env`."""

    # Exchange credentials — required, deliberately no defaults.
    binance_api_key: str
    binance_api_secret: str
    # Infrastructure connection URLs.
    redis_url: str = "redis://localhost:6379"
    database_url: str = "postgresql://trading:trading@localhost:5432/trading"
    log_level: str = "INFO"
    # Risk limits; presumably max position is a portfolio fraction and the
    # others are percentages — TODO confirm against the risk service.
    risk_max_position_size: float = 0.1
    risk_stop_loss_pct: float = 5.0
    risk_daily_loss_limit_pct: float = 10.0
    # When True, trading runs in simulation mode by default.
    dry_run: bool = True

    # Pydantic v2 settings config: also read values from a local .env file.
    model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py
new file mode 100644
index 0000000..6bddd7c
--- /dev/null
+++ b/shared/src/shared/db.py
@@ -0,0 +1,184 @@
+"""Database layer using asyncpg for the trading platform."""
+from datetime import datetime, timezone
+from typing import Optional
+
+import asyncpg
+
+from shared.models import Candle, Signal, Order, OrderStatus
+
+
# Idempotent schema bootstrap executed by Database.init_tables().
# Defines the six tables used by the platform: candles, signals, orders,
# trades, positions, and portfolio_snapshots.
_INIT_SQL = """
CREATE TABLE IF NOT EXISTS candles (
    symbol TEXT NOT NULL,
    timeframe TEXT NOT NULL,
    open_time TIMESTAMPTZ NOT NULL,
    open NUMERIC NOT NULL,
    high NUMERIC NOT NULL,
    low NUMERIC NOT NULL,
    close NUMERIC NOT NULL,
    volume NUMERIC NOT NULL,
    PRIMARY KEY (symbol, timeframe, open_time)
);

CREATE TABLE IF NOT EXISTS signals (
    id TEXT PRIMARY KEY,
    strategy TEXT NOT NULL,
    symbol TEXT NOT NULL,
    side TEXT NOT NULL,
    price NUMERIC NOT NULL,
    quantity NUMERIC NOT NULL,
    reason TEXT,
    created_at TIMESTAMPTZ NOT NULL
);

CREATE TABLE IF NOT EXISTS orders (
    id TEXT PRIMARY KEY,
    signal_id TEXT REFERENCES signals(id),
    symbol TEXT NOT NULL,
    side TEXT NOT NULL,
    type TEXT NOT NULL,
    price NUMERIC NOT NULL,
    quantity NUMERIC NOT NULL,
    status TEXT NOT NULL DEFAULT 'PENDING',
    created_at TIMESTAMPTZ NOT NULL,
    filled_at TIMESTAMPTZ
);

CREATE TABLE IF NOT EXISTS trades (
    id TEXT PRIMARY KEY,
    order_id TEXT REFERENCES orders(id),
    symbol TEXT NOT NULL,
    side TEXT NOT NULL,
    price NUMERIC NOT NULL,
    quantity NUMERIC NOT NULL,
    fee NUMERIC NOT NULL DEFAULT 0,
    traded_at TIMESTAMPTZ NOT NULL
);

CREATE TABLE IF NOT EXISTS positions (
    symbol TEXT PRIMARY KEY,
    quantity NUMERIC NOT NULL,
    avg_entry_price NUMERIC NOT NULL,
    current_price NUMERIC NOT NULL,
    updated_at TIMESTAMPTZ NOT NULL
);

CREATE TABLE IF NOT EXISTS portfolio_snapshots (
    id SERIAL PRIMARY KEY,
    total_value NUMERIC NOT NULL,
    realized_pnl NUMERIC NOT NULL,
    unrealized_pnl NUMERIC NOT NULL,
    snapshot_at TIMESTAMPTZ NOT NULL
);
"""
+
+
class Database:
    """Async database access layer backed by an asyncpg connection pool.

    Call :meth:`connect` before using any query method and :meth:`close`
    on shutdown.  Every query checks a single connection out of the pool
    and returns it when done.
    """

    def __init__(self, database_url: str) -> None:
        self._database_url = database_url
        self._pool: Optional[asyncpg.Pool] = None

    async def connect(self) -> None:
        """Create the asyncpg connection pool (idempotent)."""
        if self._pool is None:
            self._pool = await asyncpg.create_pool(self._database_url)

    async def close(self) -> None:
        """Close the asyncpg connection pool."""
        if self._pool:
            await self._pool.close()
            self._pool = None

    def _acquire(self):
        """Return a context manager that checks one connection out of the pool.

        Raises:
            RuntimeError: if connect() has not been called yet.

        BUGFIX: the previous code did ``async with self._pool as conn`` —
        that treats the *pool itself* as the context manager, so ``conn``
        was the pool and, worse, exiting the block closed the entire pool,
        breaking every subsequent query.  ``pool.acquire()`` checks out a
        single connection and releases it back to the pool on exit.
        """
        if self._pool is None:
            raise RuntimeError("Database is not connected; call connect() first")
        return self._pool.acquire()

    async def init_tables(self) -> None:
        """Create all tables if they do not exist."""
        async with self._acquire() as conn:
            await conn.execute(_INIT_SQL)

    async def insert_candle(self, candle: Candle) -> None:
        """Insert a candle row, ignoring duplicates (same PK)."""
        sql = """
        INSERT INTO candles (symbol, timeframe, open_time, open, high, low, close, volume)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
        ON CONFLICT DO NOTHING
        """
        async with self._acquire() as conn:
            await conn.execute(
                sql,
                candle.symbol,
                candle.timeframe,
                candle.open_time,
                candle.open,
                candle.high,
                candle.low,
                candle.close,
                candle.volume,
            )

    async def insert_signal(self, signal: Signal) -> None:
        """Insert a signal row."""
        sql = """
        INSERT INTO signals (id, strategy, symbol, side, price, quantity, reason, created_at)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
        """
        async with self._acquire() as conn:
            await conn.execute(
                sql,
                signal.id,
                signal.strategy,
                signal.symbol,
                signal.side.value,
                signal.price,
                signal.quantity,
                signal.reason,
                signal.created_at,
            )

    async def insert_order(self, order: Order) -> None:
        """Insert an order row."""
        sql = """
        INSERT INTO orders (id, signal_id, symbol, side, type, price, quantity, status, created_at, filled_at)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
        """
        async with self._acquire() as conn:
            await conn.execute(
                sql,
                order.id,
                order.signal_id,
                order.symbol,
                order.side.value,
                order.type.value,
                order.price,
                order.quantity,
                order.status.value,
                order.created_at,
                order.filled_at,
            )

    async def update_order_status(
        self,
        order_id: str,
        status: OrderStatus,
        filled_at: Optional[datetime] = None,
    ) -> None:
        """Update the status (and optionally filled_at) of an order."""
        sql = """
        UPDATE orders SET status = $2, filled_at = $3 WHERE id = $1
        """
        async with self._acquire() as conn:
            await conn.execute(sql, order_id, status.value, filled_at)

    async def get_candles(
        self, symbol: str, timeframe: str, limit: int = 500
    ) -> list[dict]:
        """Retrieve up to *limit* candles ordered by open_time descending."""
        sql = """
        SELECT symbol, timeframe, open_time, open, high, low, close, volume
        FROM candles
        WHERE symbol = $1 AND timeframe = $2
        ORDER BY open_time DESC
        LIMIT $3
        """
        async with self._acquire() as conn:
            rows = await conn.fetch(sql, symbol, timeframe, limit)
            return [dict(row) for row in rows]
diff --git a/shared/src/shared/events.py b/shared/src/shared/events.py
new file mode 100644
index 0000000..1db2bee
--- /dev/null
+++ b/shared/src/shared/events.py
@@ -0,0 +1,75 @@
+"""Event types and serialization for the trading platform."""
+from enum import Enum
+from typing import Any
+
+from pydantic import BaseModel
+
+from shared.models import Candle, Signal, Order
+
+
class EventType(str, Enum):
    """Discriminator carried in the "type" field of serialized events."""

    CANDLE = "CANDLE"
    SIGNAL = "SIGNAL"
    ORDER = "ORDER"
+
+
class CandleEvent(BaseModel):
    """Envelope for a Candle travelling over the broker."""

    type: EventType = EventType.CANDLE
    data: Candle

    def to_dict(self) -> dict:
        """Serialize to a JSON-compatible dict."""
        body = self.data.model_dump(mode="json")
        return {"type": self.type, "data": body}

    @classmethod
    def from_raw(cls, raw: dict) -> "CandleEvent":
        """Rebuild the event from a dict produced by to_dict()."""
        candle = Candle(**raw["data"])
        return cls(type=raw["type"], data=candle)
+
+
class SignalEvent(BaseModel):
    """Envelope for a Signal travelling over the broker."""

    type: EventType = EventType.SIGNAL
    data: Signal

    def to_dict(self) -> dict:
        """Serialize to a JSON-compatible dict."""
        body = self.data.model_dump(mode="json")
        return {"type": self.type, "data": body}

    @classmethod
    def from_raw(cls, raw: dict) -> "SignalEvent":
        """Rebuild the event from a dict produced by to_dict()."""
        signal = Signal(**raw["data"])
        return cls(type=raw["type"], data=signal)
+
+
class OrderEvent(BaseModel):
    """Envelope for an Order travelling over the broker."""

    type: EventType = EventType.ORDER
    data: Order

    def to_dict(self) -> dict:
        """Serialize to a JSON-compatible dict."""
        body = self.data.model_dump(mode="json")
        return {"type": self.type, "data": body}

    @classmethod
    def from_raw(cls, raw: dict) -> "OrderEvent":
        """Rebuild the event from a dict produced by to_dict()."""
        order = Order(**raw["data"])
        return cls(type=raw["type"], data=order)
+
+
# Maps each EventType to the concrete event class that can deserialize it;
# consumed by Event.from_dict().
_EVENT_TYPE_MAP = {
    EventType.CANDLE: CandleEvent,
    EventType.SIGNAL: SignalEvent,
    EventType.ORDER: OrderEvent,
}
+
+
class Event:
    """Dispatcher for deserializing events from raw dicts."""

    @staticmethod
    def from_dict(data: dict) -> Any:
        """Deserialize *data* into the matching concrete event instance.

        Raises:
            ValueError: if the dict has no "type" field, or the value is
                not a known EventType.
        """
        try:
            raw_type = data["type"]
        except KeyError as err:
            # Previously this leaked a bare KeyError; surface a clear error.
            raise ValueError("event dict is missing the 'type' field") from err
        event_type = EventType(raw_type)  # ValueError on unknown type
        return _EVENT_TYPE_MAP[event_type].from_raw(data)
diff --git a/shared/src/shared/models.py b/shared/src/shared/models.py
new file mode 100644
index 0000000..4cb1081
--- /dev/null
+++ b/shared/src/shared/models.py
@@ -0,0 +1,72 @@
+"""Shared Pydantic models for the trading platform."""
+import uuid
+from decimal import Decimal
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel, Field, computed_field
+
+
class OrderSide(str, Enum):
    """Direction of a trade."""

    BUY = "BUY"
    SELL = "SELL"
+
+
class OrderType(str, Enum):
    """Execution style of an order."""

    MARKET = "MARKET"
    LIMIT = "LIMIT"
+
+
class OrderStatus(str, Enum):
    """Lifecycle state of an order (orders start as PENDING)."""

    PENDING = "PENDING"
    FILLED = "FILLED"
    CANCELLED = "CANCELLED"
    FAILED = "FAILED"
+
+
class Candle(BaseModel):
    """A single OHLCV candle for one symbol and timeframe."""

    symbol: str
    timeframe: str  # e.g. "1m" (format used by tests and the candles PK)
    open_time: datetime
    open: Decimal
    high: Decimal
    low: Decimal
    close: Decimal
    volume: Decimal
+
+
class Signal(BaseModel):
    """A trade recommendation emitted by a strategy."""

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))  # random UUID
    strategy: str  # name of the strategy that produced the signal
    symbol: str
    side: OrderSide
    price: Decimal
    quantity: Decimal
    reason: str  # human-readable explanation, e.g. "RSI oversold"
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))  # aware UTC
+
+
class Order(BaseModel):
    """An order derived from a Signal, tracked through its lifecycle."""

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))  # random UUID
    signal_id: str  # id of the Signal this order was created from
    symbol: str
    side: OrderSide
    type: OrderType
    price: Decimal
    quantity: Decimal
    status: OrderStatus = OrderStatus.PENDING  # new orders start as PENDING
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))  # aware UTC
    filled_at: Optional[datetime] = None  # None until the order fills
+
+
class Position(BaseModel):
    """An open position in one symbol, marked to the latest price."""

    symbol: str
    quantity: Decimal
    avg_entry_price: Decimal
    current_price: Decimal

    @computed_field
    @property
    def unrealized_pnl(self) -> Decimal:
        """P&L if the position were closed at current_price."""
        price_move = self.current_price - self.avg_entry_price
        return price_move * self.quantity
diff --git a/shared/tests/__init__.py b/shared/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/shared/tests/__init__.py
diff --git a/shared/tests/test_broker.py b/shared/tests/test_broker.py
new file mode 100644
index 0000000..d3a3569
--- /dev/null
+++ b/shared/tests/test_broker.py
@@ -0,0 +1,66 @@
+"""Tests for the Redis broker."""
+import pytest
+import json
+from unittest.mock import AsyncMock, MagicMock, patch
+
+
@pytest.mark.asyncio
async def test_broker_publish():
    """Test that publish calls xadd on the redis connection."""
    with patch("redis.asyncio.from_url") as mock_from_url:
        fake_redis = AsyncMock()
        mock_from_url.return_value = fake_redis

        from shared.broker import RedisBroker

        broker = RedisBroker("redis://localhost:6379")
        await broker.publish("candles", {"type": "CANDLE", "symbol": "BTCUSDT"})

        fake_redis.xadd.assert_called_once()
        stream_arg, fields_arg = fake_redis.xadd.call_args[0]
        assert stream_arg == "candles"
        assert "payload" in fields_arg
        assert json.loads(fields_arg["payload"])["type"] == "CANDLE"
+
+
@pytest.mark.asyncio
async def test_broker_subscribe_returns_messages():
    """Test that read parses xread response correctly."""
    with patch("redis.asyncio.from_url") as mock_from_url:
        fake_redis = AsyncMock()
        mock_from_url.return_value = fake_redis

        encoded = json.dumps({"type": "CANDLE", "symbol": "ETHUSDT"}).encode()
        fake_redis.xread.return_value = [
            [b"candles", [(b"1234567890-0", {b"payload": encoded})]],
        ]

        from shared.broker import RedisBroker

        broker = RedisBroker("redis://localhost:6379")
        messages = await broker.read("candles", last_id="$")

        fake_redis.xread.assert_called_once()
        assert messages == [{"type": "CANDLE", "symbol": "ETHUSDT"}]
+
+
@pytest.mark.asyncio
async def test_broker_close():
    """Test that close calls aclose on the redis connection."""
    with patch("redis.asyncio.from_url") as mock_from_url:
        fake_redis = AsyncMock()
        mock_from_url.return_value = fake_redis

        from shared.broker import RedisBroker

        await RedisBroker("redis://localhost:6379").close()

        fake_redis.aclose.assert_called_once()
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
new file mode 100644
index 0000000..c31e487
--- /dev/null
+++ b/shared/tests/test_db.py
@@ -0,0 +1,70 @@
+"""Tests for the database layer."""
+import pytest
+from decimal import Decimal
+from datetime import datetime, timezone
+from unittest.mock import AsyncMock, MagicMock, patch, call
+
+
def make_candle():
    """Build a fixed BTCUSDT 1m candle fixture."""
    from shared.models import Candle

    fields = {
        "symbol": "BTCUSDT",
        "timeframe": "1m",
        "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc),
        "open": Decimal("50000"),
        "high": Decimal("51000"),
        "low": Decimal("49500"),
        "close": Decimal("50500"),
        "volume": Decimal("100"),
    }
    return Candle(**fields)
+
+
@pytest.mark.asyncio
async def test_db_init_sql_creates_tables():
    """Verify that init_tables SQL references all required table names."""
    with patch("asyncpg.create_pool", new_callable=AsyncMock) as mock_pool:
        mock_conn = AsyncMock()

        # Capture the SQL that gets executed
        executed_sqls = []

        async def capture_execute(sql, *args, **kwargs):
            executed_sqls.append(sql)

        mock_conn.execute = capture_execute

        pool = mock_pool.return_value
        # Support BOTH ways Database may check out a connection:
        # the legacy `async with pool` and the correct `async with
        # pool.acquire()` — the old mock only covered the former, coupling
        # the test to an implementation detail.
        pool.__aenter__ = AsyncMock(return_value=mock_conn)
        pool.__aexit__ = AsyncMock(return_value=False)
        acquire_cm = AsyncMock()
        acquire_cm.__aenter__ = AsyncMock(return_value=mock_conn)
        acquire_cm.__aexit__ = AsyncMock(return_value=False)
        pool.acquire = MagicMock(return_value=acquire_cm)

        from shared.db import Database
        db = Database("postgresql://trading:trading@localhost:5432/trading")
        db._pool = pool
        await db.init_tables()

        combined_sql = " ".join(executed_sqls)
        for table in ["candles", "signals", "orders", "trades", "positions", "portfolio_snapshots"]:
            assert table in combined_sql, f"Table '{table}' not found in SQL"
+
+
@pytest.mark.asyncio
async def test_db_insert_candle():
    """Verify that insert_candle executes INSERT INTO candles."""
    with patch("asyncpg.create_pool", new_callable=AsyncMock) as mock_pool:
        mock_conn = AsyncMock()

        executed = []

        async def capture_execute(sql, *args, **kwargs):
            executed.append((sql, args))

        mock_conn.execute = capture_execute

        pool = mock_pool.return_value
        # Support BOTH connection-checkout patterns (`async with pool` and
        # `async with pool.acquire()`); the old mock only covered the
        # former, coupling the test to an implementation detail.
        pool.__aenter__ = AsyncMock(return_value=mock_conn)
        pool.__aexit__ = AsyncMock(return_value=False)
        acquire_cm = AsyncMock()
        acquire_cm.__aenter__ = AsyncMock(return_value=mock_conn)
        acquire_cm.__aexit__ = AsyncMock(return_value=False)
        pool.acquire = MagicMock(return_value=acquire_cm)

        from shared.db import Database
        db = Database("postgresql://trading:trading@localhost:5432/trading")
        db._pool = pool
        candle = make_candle()
        await db.insert_candle(candle)

        assert any("INSERT INTO candles" in sql for sql, _ in executed), \
            "Expected INSERT INTO candles"
diff --git a/shared/tests/test_events.py b/shared/tests/test_events.py
new file mode 100644
index 0000000..4bc7981
--- /dev/null
+++ b/shared/tests/test_events.py
@@ -0,0 +1,80 @@
+"""Tests for shared event types."""
+import pytest
+from decimal import Decimal
+from datetime import datetime, timezone
+
+
def make_candle():
    """Build a fixed BTCUSDT 1m candle fixture."""
    from shared.models import Candle

    fields = {
        "symbol": "BTCUSDT",
        "timeframe": "1m",
        "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc),
        "open": Decimal("50000"),
        "high": Decimal("51000"),
        "low": Decimal("49500"),
        "close": Decimal("50500"),
        "volume": Decimal("100"),
    }
    return Candle(**fields)
+
+
def make_signal():
    """Build a fixed BUY signal fixture."""
    from shared.models import OrderSide, Signal

    fields = {
        "strategy": "test",
        "symbol": "BTCUSDT",
        "side": OrderSide.BUY,
        "price": Decimal("50000"),
        "quantity": Decimal("0.01"),
        "reason": "test signal",
    }
    return Signal(**fields)
+
+
def test_candle_event_serialize():
    """Test CandleEvent serializes to dict correctly."""
    from shared.events import CandleEvent, EventType

    serialized = CandleEvent(data=make_candle()).to_dict()

    assert serialized["type"] == EventType.CANDLE
    payload = serialized["data"]
    assert payload["symbol"] == "BTCUSDT"
    assert payload["timeframe"] == "1m"
+
+
def test_candle_event_deserialize():
    """Test CandleEvent round-trips through to_dict/from_raw."""
    from shared.events import CandleEvent, EventType

    original = CandleEvent(data=make_candle())
    restored = CandleEvent.from_raw(original.to_dict())

    assert restored.type == EventType.CANDLE
    assert restored.data.symbol == "BTCUSDT"
    assert restored.data.close == Decimal("50500")
+
+
def test_signal_event_serialize():
    """Test SignalEvent serializes to dict correctly."""
    from shared.events import EventType, SignalEvent

    serialized = SignalEvent(data=make_signal()).to_dict()

    assert serialized["type"] == EventType.SIGNAL
    payload = serialized["data"]
    assert payload["symbol"] == "BTCUSDT"
    assert payload["strategy"] == "test"
+
+
def test_event_from_dict_dispatch():
    """Test Event.from_dict dispatches to correct class."""
    from shared.events import CandleEvent, Event, SignalEvent

    candle_dict = CandleEvent(data=make_candle()).to_dict()
    assert isinstance(Event.from_dict(candle_dict), CandleEvent)

    signal_dict = SignalEvent(data=make_signal()).to_dict()
    assert isinstance(Event.from_dict(signal_dict), SignalEvent)
diff --git a/shared/tests/test_models.py b/shared/tests/test_models.py
new file mode 100644
index 0000000..f1d92ec
--- /dev/null
+++ b/shared/tests/test_models.py
@@ -0,0 +1,100 @@
+"""Tests for shared models and settings."""
+import os
+import pytest
+from decimal import Decimal
+from datetime import datetime, timezone
+from unittest.mock import patch
+
+
def test_settings_defaults():
    """Test that Settings has correct defaults."""
    env = {
        "BINANCE_API_KEY": "test_key",
        "BINANCE_API_SECRET": "test_secret",
    }
    # clear=True isolates the test from ambient environment variables
    # (e.g. a developer's REDIS_URL or DRY_RUN) that pydantic-settings
    # would otherwise pick up and use to override the defaults under test.
    with patch.dict(os.environ, env, clear=True):
        from shared.config import Settings
        settings = Settings()
        assert settings.redis_url == "redis://localhost:6379"
        assert settings.database_url == "postgresql://trading:trading@localhost:5432/trading"
        assert settings.log_level == "INFO"
        assert settings.risk_max_position_size == 0.1
        assert settings.risk_stop_loss_pct == 5.0
        assert settings.risk_daily_loss_limit_pct == 10.0
        assert settings.dry_run is True
+
+
def test_candle_creation():
    """Test Candle model creation."""
    from shared.models import Candle

    now = datetime.now(timezone.utc)
    values = {
        "symbol": "BTCUSDT",
        "timeframe": "1m",
        "open_time": now,
        "open": Decimal("50000.00"),
        "high": Decimal("51000.00"),
        "low": Decimal("49500.00"),
        "close": Decimal("50500.00"),
        "volume": Decimal("100.5"),
    }
    candle = Candle(**values)

    # Every supplied field round-trips unchanged.
    for field, expected in values.items():
        assert getattr(candle, field) == expected
+
+
def test_signal_creation():
    """Test Signal model creation."""
    from shared.models import OrderSide, Signal

    values = {
        "strategy": "rsi_strategy",
        "symbol": "BTCUSDT",
        "side": OrderSide.BUY,
        "price": Decimal("50000.00"),
        "quantity": Decimal("0.01"),
        "reason": "RSI oversold",
    }
    signal = Signal(**values)

    # Supplied fields round-trip; id/created_at are auto-generated.
    for field, expected in values.items():
        assert getattr(signal, field) == expected
    assert signal.id is not None
    assert signal.created_at is not None
+
+
def test_order_creation():
    """Test Order model creation with defaults."""
    import uuid

    from shared.models import Order, OrderSide, OrderStatus, OrderType

    linked_signal = str(uuid.uuid4())
    order = Order(
        signal_id=linked_signal,
        symbol="BTCUSDT",
        side=OrderSide.BUY,
        type=OrderType.MARKET,
        price=Decimal("50000.00"),
        quantity=Decimal("0.01"),
    )

    # Explicit fields round-trip; generated/default fields are populated.
    assert order.signal_id == linked_signal
    assert order.id is not None
    assert order.status == OrderStatus.PENDING
    assert order.filled_at is None
    assert order.created_at is not None
+
+
def test_position_unrealized_pnl():
    """Test Position unrealized_pnl computed property."""
    from shared.models import Position

    position = Position(
        symbol="BTCUSDT",
        quantity=Decimal("0.1"),
        avg_entry_price=Decimal("50000"),
        current_price=Decimal("51000"),
    )

    # Expect 0.1 * (51000 - 50000) == 100.
    assert position.unrealized_pnl == Decimal("100")