summaryrefslogtreecommitdiff
path: root/shared
diff options
context:
space:
mode:
Diffstat (limited to 'shared')
-rw-r--r--shared/src/shared/broker.py5
-rw-r--r--shared/src/shared/config.py1
-rw-r--r--shared/src/shared/db.py5
-rw-r--r--shared/src/shared/events.py1
-rw-r--r--shared/src/shared/healthcheck.py1
-rw-r--r--shared/src/shared/logging.py1
-rw-r--r--shared/src/shared/metrics.py1
-rw-r--r--shared/src/shared/models.py1
-rw-r--r--shared/src/shared/notifier.py11
-rw-r--r--shared/src/shared/resilience.py2
-rw-r--r--shared/src/shared/sa_models.py16
-rw-r--r--shared/tests/test_broker.py6
-rw-r--r--shared/tests/test_db.py18
-rw-r--r--shared/tests/test_events.py10
-rw-r--r--shared/tests/test_healthcheck.py3
-rw-r--r--shared/tests/test_logging.py1
-rw-r--r--shared/tests/test_metrics.py3
-rw-r--r--shared/tests/test_models.py18
-rw-r--r--shared/tests/test_notifier.py4
-rw-r--r--shared/tests/test_resilience.py3
-rw-r--r--shared/tests/test_sa_models.py9
21 files changed, 71 insertions, 49 deletions
diff --git a/shared/src/shared/broker.py b/shared/src/shared/broker.py
index 0f87b06..9c6c4c6 100644
--- a/shared/src/shared/broker.py
+++ b/shared/src/shared/broker.py
@@ -1,4 +1,5 @@
"""Redis Streams broker for the trading platform."""
+
import json
from typing import Any
@@ -24,9 +25,7 @@ class RedisBroker:
block: int = 0,
) -> list[dict[str, Any]]:
"""Read messages from a Redis stream."""
- results = await self._redis.xread(
- {stream: last_id}, count=count, block=block
- )
+ results = await self._redis.xread({stream: last_id}, count=count, block=block)
messages = []
if results:
for _stream, entries in results:
diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py
index 511654f..47bc2b1 100644
--- a/shared/src/shared/config.py
+++ b/shared/src/shared/config.py
@@ -1,4 +1,5 @@
"""Shared configuration settings for the trading platform."""
+
from pydantic_settings import BaseSettings
diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py
index 95e487e..f9b7f56 100644
--- a/shared/src/shared/db.py
+++ b/shared/src/shared/db.py
@@ -1,4 +1,5 @@
"""Database layer using SQLAlchemy 2.0 async ORM for the trading platform."""
+
from datetime import datetime
from typing import Optional
@@ -107,9 +108,7 @@ class Database:
await session.execute(stmt)
await session.commit()
- async def get_candles(
- self, symbol: str, timeframe: str, limit: int = 500
- ) -> list[dict]:
+ async def get_candles(self, symbol: str, timeframe: str, limit: int = 500) -> list[dict]:
"""Retrieve candles ordered by open_time descending."""
stmt = (
select(CandleRow)
diff --git a/shared/src/shared/events.py b/shared/src/shared/events.py
index 1db2bee..72f8865 100644
--- a/shared/src/shared/events.py
+++ b/shared/src/shared/events.py
@@ -1,4 +1,5 @@
"""Event types and serialization for the trading platform."""
+
from enum import Enum
from typing import Any
diff --git a/shared/src/shared/healthcheck.py b/shared/src/shared/healthcheck.py
index 8294294..be02712 100644
--- a/shared/src/shared/healthcheck.py
+++ b/shared/src/shared/healthcheck.py
@@ -1,4 +1,5 @@
"""Health check HTTP server with Prometheus metrics endpoint."""
+
from __future__ import annotations
import time
diff --git a/shared/src/shared/logging.py b/shared/src/shared/logging.py
index b873eaf..9e42cdc 100644
--- a/shared/src/shared/logging.py
+++ b/shared/src/shared/logging.py
@@ -1,4 +1,5 @@
"""Structured logging configuration using structlog."""
+
from __future__ import annotations
import logging
diff --git a/shared/src/shared/metrics.py b/shared/src/shared/metrics.py
index 3b00c5d..cd239f3 100644
--- a/shared/src/shared/metrics.py
+++ b/shared/src/shared/metrics.py
@@ -1,4 +1,5 @@
"""Prometheus metrics for trading platform services."""
+
from __future__ import annotations
from prometheus_client import Counter, Gauge, Histogram, CollectorRegistry, REGISTRY
diff --git a/shared/src/shared/models.py b/shared/src/shared/models.py
index 4cb1081..0e8ca44 100644
--- a/shared/src/shared/models.py
+++ b/shared/src/shared/models.py
@@ -1,4 +1,5 @@
"""Shared Pydantic models for the trading platform."""
+
import uuid
from decimal import Decimal
from datetime import datetime, timezone
diff --git a/shared/src/shared/notifier.py b/shared/src/shared/notifier.py
index de86f87..f03919c 100644
--- a/shared/src/shared/notifier.py
+++ b/shared/src/shared/notifier.py
@@ -1,4 +1,5 @@
"""Telegram notification service for the trading platform."""
+
import asyncio
import logging
from decimal import Decimal
@@ -63,9 +64,7 @@ class TelegramNotifier:
body,
)
except Exception:
- logger.exception(
- "Telegram send failed (attempt %d/%d)", attempt, MAX_RETRIES
- )
+ logger.exception("Telegram send failed (attempt %d/%d)", attempt, MAX_RETRIES)
if attempt < MAX_RETRIES:
await asyncio.sleep(attempt)
@@ -96,11 +95,7 @@ class TelegramNotifier:
async def send_error(self, error: str, service: str) -> None:
"""Format and send an error alert."""
- msg = (
- "<b>🚨 Error Alert</b>\n"
- f"Service: <b>{service}</b>\n"
- f"Error: {error}"
- )
+ msg = f"<b>🚨 Error Alert</b>\nService: <b>{service}</b>\nError: {error}"
await self.send(msg)
async def send_daily_summary(
diff --git a/shared/src/shared/resilience.py b/shared/src/shared/resilience.py
index d4e963b..e43fd21 100644
--- a/shared/src/shared/resilience.py
+++ b/shared/src/shared/resilience.py
@@ -35,7 +35,7 @@ def retry_with_backoff(
except Exception as exc:
last_exc = exc
if attempt < max_retries:
- delay = min(base_delay * (2 ** attempt), max_delay)
+ delay = min(base_delay * (2**attempt), max_delay)
jitter = delay * random.uniform(0, 0.5)
total_delay = delay + jitter
logger.warning(
diff --git a/shared/src/shared/sa_models.py b/shared/src/shared/sa_models.py
index 0537846..8386ba8 100644
--- a/shared/src/shared/sa_models.py
+++ b/shared/src/shared/sa_models.py
@@ -41,17 +41,13 @@ class OrderRow(Base):
__tablename__ = "orders"
id: Mapped[str] = mapped_column(Text, primary_key=True)
- signal_id: Mapped[str | None] = mapped_column(
- Text, ForeignKey("signals.id")
- )
+ signal_id: Mapped[str | None] = mapped_column(Text, ForeignKey("signals.id"))
symbol: Mapped[str] = mapped_column(Text, nullable=False)
side: Mapped[str] = mapped_column(Text, nullable=False)
type: Mapped[str] = mapped_column(Text, nullable=False)
price: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
- status: Mapped[str] = mapped_column(
- Text, nullable=False, server_default="PENDING"
- )
+ status: Mapped[str] = mapped_column(Text, nullable=False, server_default="PENDING")
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
filled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
@@ -60,16 +56,12 @@ class TradeRow(Base):
__tablename__ = "trades"
id: Mapped[str] = mapped_column(Text, primary_key=True)
- order_id: Mapped[str | None] = mapped_column(
- Text, ForeignKey("orders.id")
- )
+ order_id: Mapped[str | None] = mapped_column(Text, ForeignKey("orders.id"))
symbol: Mapped[str] = mapped_column(Text, nullable=False)
side: Mapped[str] = mapped_column(Text, nullable=False)
price: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False)
- fee: Mapped[Decimal] = mapped_column(
- Numeric, nullable=False, server_default="0"
- )
+ fee: Mapped[Decimal] = mapped_column(Numeric, nullable=False, server_default="0")
traded_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
diff --git a/shared/tests/test_broker.py b/shared/tests/test_broker.py
index d3a3569..ea8b148 100644
--- a/shared/tests/test_broker.py
+++ b/shared/tests/test_broker.py
@@ -1,7 +1,8 @@
"""Tests for the Redis broker."""
+
import pytest
import json
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, patch
@pytest.mark.asyncio
@@ -12,6 +13,7 @@ async def test_broker_publish():
mock_from_url.return_value = mock_redis
from shared.broker import RedisBroker
+
broker = RedisBroker("redis://localhost:6379")
data = {"type": "CANDLE", "symbol": "BTCUSDT"}
await broker.publish("candles", data)
@@ -43,6 +45,7 @@ async def test_broker_subscribe_returns_messages():
]
from shared.broker import RedisBroker
+
broker = RedisBroker("redis://localhost:6379")
messages = await broker.read("candles", last_id="$")
@@ -60,6 +63,7 @@ async def test_broker_close():
mock_from_url.return_value = mock_redis
from shared.broker import RedisBroker
+
broker = RedisBroker("redis://localhost:6379")
await broker.close()
diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py
index 45d5dcd..b9a9d56 100644
--- a/shared/tests/test_db.py
+++ b/shared/tests/test_db.py
@@ -1,12 +1,14 @@
"""Tests for the SQLAlchemy async database layer."""
+
import pytest
from decimal import Decimal
from datetime import datetime, timezone
-from unittest.mock import AsyncMock, MagicMock, patch, call
+from unittest.mock import AsyncMock, MagicMock, patch
def make_candle():
from shared.models import Candle
+
return Candle(
symbol="BTCUSDT",
timeframe="1m",
@@ -21,6 +23,7 @@ def make_candle():
def make_signal():
from shared.models import Signal, OrderSide
+
return Signal(
id="sig-1",
strategy="ma_cross",
@@ -35,6 +38,7 @@ def make_signal():
def make_order():
from shared.models import Order, OrderSide, OrderType, OrderStatus
+
return Order(
id="ord-1",
signal_id="sig-1",
@@ -51,21 +55,25 @@ def make_order():
class TestDatabaseConstructor:
def test_stores_url(self):
from shared.db import Database
+
db = Database("postgresql://user:pass@localhost/db")
assert db._database_url == "postgresql+asyncpg://user:pass@localhost/db"
def test_converts_url_prefix(self):
from shared.db import Database
+
db = Database("postgresql://host/db")
assert db._database_url.startswith("postgresql+asyncpg://")
def test_keeps_asyncpg_prefix(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
assert db._database_url == "postgresql+asyncpg://host/db"
def test_get_session_exists(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
assert hasattr(db, "get_session")
@@ -74,6 +82,7 @@ class TestDatabaseConnect:
@pytest.mark.asyncio
async def test_connect_creates_engine_and_tables(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
mock_conn = AsyncMock()
@@ -94,6 +103,7 @@ class TestDatabaseConnect:
@pytest.mark.asyncio
async def test_init_tables_is_alias_for_connect(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
mock_conn = AsyncMock()
@@ -116,6 +126,7 @@ class TestDatabaseClose:
@pytest.mark.asyncio
async def test_close_disposes_engine(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
mock_engine = AsyncMock()
db._engine = mock_engine
@@ -127,6 +138,7 @@ class TestInsertCandle:
@pytest.mark.asyncio
async def test_insert_candle_uses_merge_and_commit(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
mock_session = AsyncMock()
@@ -147,6 +159,7 @@ class TestInsertSignal:
@pytest.mark.asyncio
async def test_insert_signal_uses_add_and_commit(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
mock_session = AsyncMock()
@@ -167,6 +180,7 @@ class TestInsertOrder:
@pytest.mark.asyncio
async def test_insert_order_uses_add_and_commit(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
mock_session = AsyncMock()
@@ -188,6 +202,7 @@ class TestUpdateOrderStatus:
async def test_update_order_status_uses_execute_and_commit(self):
from shared.db import Database
from shared.models import OrderStatus
+
db = Database("postgresql+asyncpg://host/db")
mock_session = AsyncMock()
@@ -207,6 +222,7 @@ class TestGetCandles:
@pytest.mark.asyncio
async def test_get_candles_returns_list_of_dicts(self):
from shared.db import Database
+
db = Database("postgresql+asyncpg://host/db")
# Create a mock row that behaves like a SA result row
diff --git a/shared/tests/test_events.py b/shared/tests/test_events.py
index 4bc7981..ab7792b 100644
--- a/shared/tests/test_events.py
+++ b/shared/tests/test_events.py
@@ -1,11 +1,12 @@
"""Tests for shared event types."""
-import pytest
+
from decimal import Decimal
from datetime import datetime, timezone
def make_candle():
from shared.models import Candle
+
return Candle(
symbol="BTCUSDT",
timeframe="1m",
@@ -20,6 +21,7 @@ def make_candle():
def make_signal():
from shared.models import Signal, OrderSide
+
return Signal(
strategy="test",
symbol="BTCUSDT",
@@ -33,6 +35,7 @@ def make_signal():
def test_candle_event_serialize():
"""Test CandleEvent serializes to dict correctly."""
from shared.events import CandleEvent, EventType
+
candle = make_candle()
event = CandleEvent(data=candle)
d = event.to_dict()
@@ -44,6 +47,7 @@ def test_candle_event_serialize():
def test_candle_event_deserialize():
"""Test CandleEvent round-trips through to_dict/from_raw."""
from shared.events import CandleEvent, EventType
+
candle = make_candle()
event = CandleEvent(data=candle)
d = event.to_dict()
@@ -56,6 +60,7 @@ def test_candle_event_deserialize():
def test_signal_event_serialize():
"""Test SignalEvent serializes to dict correctly."""
from shared.events import SignalEvent, EventType
+
signal = make_signal()
event = SignalEvent(data=signal)
d = event.to_dict()
@@ -66,7 +71,8 @@ def test_signal_event_serialize():
def test_event_from_dict_dispatch():
"""Test Event.from_dict dispatches to correct class."""
- from shared.events import Event, CandleEvent, SignalEvent, EventType
+ from shared.events import Event, CandleEvent, SignalEvent
+
candle = make_candle()
event = CandleEvent(data=candle)
d = event.to_dict()
diff --git a/shared/tests/test_healthcheck.py b/shared/tests/test_healthcheck.py
index 1af86b1..6970a8f 100644
--- a/shared/tests/test_healthcheck.py
+++ b/shared/tests/test_healthcheck.py
@@ -1,6 +1,6 @@
"""Tests for health check server."""
+
import pytest
-import asyncio
from prometheus_client import CollectorRegistry
@@ -11,6 +11,7 @@ def registry():
def make_server(service_name="test-service", port=8080, registry=None):
from shared.healthcheck import HealthCheckServer
+
return HealthCheckServer(service_name, port=port, registry=registry)
diff --git a/shared/tests/test_logging.py b/shared/tests/test_logging.py
index 4abd254..2ffddcd 100644
--- a/shared/tests/test_logging.py
+++ b/shared/tests/test_logging.py
@@ -1,4 +1,5 @@
"""Tests for shared structured logging module."""
+
import io
import json
import logging
diff --git a/shared/tests/test_metrics.py b/shared/tests/test_metrics.py
index 079f01c..3fd72a7 100644
--- a/shared/tests/test_metrics.py
+++ b/shared/tests/test_metrics.py
@@ -1,10 +1,11 @@
"""Tests for Prometheus metrics utilities."""
-import pytest
+
from prometheus_client import CollectorRegistry
def make_metrics(service_name="test-service", registry=None):
from shared.metrics import ServiceMetrics
+
return ServiceMetrics(service_name, registry=registry)
diff --git a/shared/tests/test_models.py b/shared/tests/test_models.py
index f1d92ec..25ab4c9 100644
--- a/shared/tests/test_models.py
+++ b/shared/tests/test_models.py
@@ -1,6 +1,6 @@
"""Tests for shared models and settings."""
+
import os
-import pytest
from decimal import Decimal
from datetime import datetime, timezone
from unittest.mock import patch
@@ -8,11 +8,15 @@ from unittest.mock import patch
def test_settings_defaults():
"""Test that Settings has correct defaults."""
- with patch.dict(os.environ, {
- "BINANCE_API_KEY": "test_key",
- "BINANCE_API_SECRET": "test_secret",
- }):
+ with patch.dict(
+ os.environ,
+ {
+ "BINANCE_API_KEY": "test_key",
+ "BINANCE_API_SECRET": "test_secret",
+ },
+ ):
from shared.config import Settings
+
settings = Settings()
assert settings.redis_url == "redis://localhost:6379"
assert settings.database_url == "postgresql://trading:trading@localhost:5432/trading"
@@ -26,6 +30,7 @@ def test_settings_defaults():
def test_candle_creation():
"""Test Candle model creation."""
from shared.models import Candle
+
now = datetime.now(timezone.utc)
candle = Candle(
symbol="BTCUSDT",
@@ -49,6 +54,7 @@ def test_candle_creation():
def test_signal_creation():
"""Test Signal model creation."""
from shared.models import Signal, OrderSide
+
signal = Signal(
strategy="rsi_strategy",
symbol="BTCUSDT",
@@ -71,6 +77,7 @@ def test_order_creation():
"""Test Order model creation with defaults."""
from shared.models import Order, OrderSide, OrderType, OrderStatus
import uuid
+
signal_id = str(uuid.uuid4())
order = Order(
signal_id=signal_id,
@@ -90,6 +97,7 @@ def test_order_creation():
def test_position_unrealized_pnl():
"""Test Position unrealized_pnl computed property."""
from shared.models import Position
+
position = Position(
symbol="BTCUSDT",
quantity=Decimal("0.1"),
diff --git a/shared/tests/test_notifier.py b/shared/tests/test_notifier.py
index 09e731a..3d29830 100644
--- a/shared/tests/test_notifier.py
+++ b/shared/tests/test_notifier.py
@@ -1,12 +1,10 @@
"""Tests for Telegram notification service."""
-import os
+
import uuid
from decimal import Decimal
-from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
-import pytest_asyncio
from shared.models import Signal, Order, OrderSide, OrderType, OrderStatus, Position
from shared.notifier import TelegramNotifier
diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py
index 514bcc2..e287777 100644
--- a/shared/tests/test_resilience.py
+++ b/shared/tests/test_resilience.py
@@ -1,5 +1,5 @@
"""Tests for retry with backoff and circuit breaker."""
-import asyncio
+
import time
import pytest
@@ -63,6 +63,7 @@ async def test_retry_raises_after_max_retries():
@pytest.mark.asyncio
async def test_retry_respects_max_delay():
"""Backoff should be capped at max_delay."""
+
@retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02)
async def always_fail():
raise RuntimeError("fail")
diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py
index de994c5..67c3c82 100644
--- a/shared/tests/test_sa_models.py
+++ b/shared/tests/test_sa_models.py
@@ -1,6 +1,5 @@
"""Tests for SQLAlchemy ORM models."""
-import pytest
from sqlalchemy import inspect
@@ -117,9 +116,7 @@ class TestOrderRow:
from shared.sa_models import OrderRow
table = OrderRow.__table__
- fk_cols = {
- fk.parent.name: fk.target_fullname for fk in table.foreign_keys
- }
+ fk_cols = {fk.parent.name: fk.target_fullname for fk in table.foreign_keys}
assert fk_cols == {"signal_id": "signals.id"}
@@ -157,9 +154,7 @@ class TestTradeRow:
from shared.sa_models import TradeRow
table = TradeRow.__table__
- fk_cols = {
- fk.parent.name: fk.target_fullname for fk in table.foreign_keys
- }
+ fk_cols = {fk.parent.name: fk.target_fullname for fk in table.foreign_keys}
assert fk_cols == {"order_id": "orders.id"}