diff options
| author | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 16:24:30 +0900 |
|---|---|---|
| committer | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 16:24:30 +0900 |
| commit | 100aa624ad3f8ad466a95f9da8af30f31f77cc9c (patch) | |
| tree | ef81b9f37872ed462a1f84ea238a130f758782d2 | |
| parent | 73eaf704584e5bf3c4499ccdd574af87304e1e5f (diff) | |
fix: resolve lint issues and final integration fixes
- Fix ambiguous variable name in binance_rest.py
- Remove unused volumes variable in volume_profile_strategy.py
- Fix import ordering in backtester main.py and test_metrics.py
- Auto-format all files with ruff
66 files changed, 268 insertions, 164 deletions
diff --git a/cli/src/trading_cli/commands/backtest.py b/cli/src/trading_cli/commands/backtest.py index 40617b6..0f0cdbe 100644 --- a/cli/src/trading_cli/commands/backtest.py +++ b/cli/src/trading_cli/commands/backtest.py @@ -16,7 +16,9 @@ def backtest(): def run(strategy, symbol, from_date, to_date, balance): """Run a backtest for a strategy.""" to_label = to_date or "now" - click.echo(f"Running backtest: strategy={strategy}, symbol={symbol}, {from_date} → {to_label}, balance={balance}...") + click.echo( + f"Running backtest: strategy={strategy}, symbol={symbol}, {from_date} → {to_label}, balance={balance}..." + ) @backtest.command() diff --git a/cli/src/trading_cli/commands/data.py b/cli/src/trading_cli/commands/data.py index 1fa5e30..25d1693 100644 --- a/cli/src/trading_cli/commands/data.py +++ b/cli/src/trading_cli/commands/data.py @@ -22,7 +22,11 @@ def collect(symbol, timeframe): @click.option("--limit", default=1000, show_default=True, help="Number of candles to fetch") def history(symbol, timeframe, since, limit): """Download historical market data for a symbol.""" - click.echo(f"Downloading {limit} {timeframe} candles for {symbol}" + (f" since {since}" if since else "") + "...") + click.echo( + f"Downloading {limit} {timeframe} candles for {symbol}" + + (f" since {since}" if since else "") + + "..." + ) @data.command("list") diff --git a/services/backtester/src/backtester/config.py b/services/backtester/src/backtester/config.py index bfbc196..5a912f3 100644 --- a/services/backtester/src/backtester/config.py +++ b/services/backtester/src/backtester/config.py @@ -1,4 +1,5 @@ """Configuration for the backtester service.""" + from pydantic_settings import BaseSettings diff --git a/services/backtester/src/backtester/engine.py b/services/backtester/src/backtester/engine.py index 386309b..0441011 100644 --- a/services/backtester/src/backtester/engine.py +++ b/services/backtester/src/backtester/engine.py @@ -1,4 +1,5 @@ """Backtesting engine that runs strategies against historical candle data.""" + from __future__ import annotations from dataclasses import dataclass, field @@ -98,9 +99,7 @@ class BacktestEngine: ) for t in simulator.trades ] - detailed = compute_detailed_metrics( - trade_records, self._initial_balance, final_balance - ) + detailed = compute_detailed_metrics(trade_records, self._initial_balance, final_balance) return BacktestResult( strategy_name=self._strategy.name, diff --git a/services/backtester/src/backtester/main.py b/services/backtester/src/backtester/main.py index ab69ee1..c9b3890 100644 --- a/services/backtester/src/backtester/main.py +++ b/services/backtester/src/backtester/main.py @@ -1,22 +1,21 @@ """Main entry point for the backtester service.""" -import sys + import os +import sys from decimal import Decimal +from shared.db import Database # noqa: E402 +from shared.models import Candle # noqa: E402 + +from backtester.config import BacktestConfig # noqa: E402 +from backtester.engine import BacktestEngine # noqa: E402 +from backtester.reporter import format_report # noqa: E402 + # Allow importing strategies from the strategy-engine service -_STRATEGY_ENGINE_PATH = os.path.join( - os.path.dirname(__file__), "../../../../strategy-engine" -) +_STRATEGY_ENGINE_PATH = os.path.join(os.path.dirname(__file__), "../../../../strategy-engine") if _STRATEGY_ENGINE_PATH not in sys.path: sys.path.insert(0, _STRATEGY_ENGINE_PATH) -from shared.db import Database -from shared.models import Candle - -from backtester.config import BacktestConfig -from backtester.engine import BacktestEngine -from backtester.reporter import format_report - async def run_backtest() -> str: """Load strategy, fetch candles, run backtest, and return a formatted report.""" @@ -35,9 +34,7 @@ async def run_backtest() -> str: strategy = strategy_cls() strategy.configure({}) except Exception as exc: - raise RuntimeError( - f"Failed to load strategy '{config.strategy_name}': {exc}" - ) from exc + raise RuntimeError(f"Failed to load strategy '{config.strategy_name}': {exc}") from exc db = Database(config.database_url) await db.connect() diff --git a/services/backtester/src/backtester/metrics.py b/services/backtester/src/backtester/metrics.py index 15be0e6..caf8477 100644 --- a/services/backtester/src/backtester/metrics.py +++ b/services/backtester/src/backtester/metrics.py @@ -1,4 +1,5 @@ """Detailed backtest metrics: Sharpe, Sortino, drawdown, and more.""" + from __future__ import annotations import math @@ -87,7 +88,9 @@ def compute_detailed_metrics( ) pairs = _pair_trades(trades) - total_return = float(final_balance - initial_balance) / float(initial_balance) if initial_balance else 0.0 + total_return = ( + float(final_balance - initial_balance) / float(initial_balance) if initial_balance else 0.0 + ) if not pairs: return DetailedMetrics( @@ -114,7 +117,9 @@ def compute_detailed_metrics( gross_profit = sum(p["pnl"] for p in wins) gross_loss = abs(sum(p["pnl"] for p in losses)) - profit_factor = gross_profit / gross_loss if gross_loss > 0 else float("inf") if gross_profit > 0 else 0.0 + profit_factor = ( + gross_profit / gross_loss if gross_loss > 0 else float("inf") if gross_profit > 0 else 0.0 + ) avg_win = gross_profit / winning_trades if winning_trades else 0.0 avg_loss = gross_loss / losing_trades if losing_trades else 0.0 @@ -123,7 +128,11 @@ def compute_detailed_metrics( # Holding periods holding_periods = [p["holding_period"] for p in pairs] - avg_holding = sum(holding_periods, timedelta(0)) / len(holding_periods) if holding_periods else timedelta(0) + avg_holding = ( + sum(holding_periods, timedelta(0)) / len(holding_periods) + if holding_periods + else timedelta(0) + ) # Build equity curve from pairs init_bal = float(initial_balance) @@ -153,7 +162,11 @@ def compute_detailed_metrics( max_dd = dd # Duration: use pair exit times if i <= len(pairs) and dd_start_idx < len(pairs): - start_time = pairs[dd_start_idx]["exit_time"] if dd_start_idx < len(pairs) else pairs[0]["entry_time"] + start_time = ( + pairs[dd_start_idx]["exit_time"] + if dd_start_idx < len(pairs) + else pairs[0]["entry_time"] + ) end_time = pairs[i - 1]["exit_time"] max_dd_duration = end_time - start_time if end_time > start_time else timedelta(0) @@ -171,7 +184,7 @@ def compute_detailed_metrics( if len(returns) > 1: mean_r = sum(returns) / len(returns) downside = [min(r, 0.0) for r in returns] - downside_var = sum(d ** 2 for d in downside) / (len(downside) - 1) + downside_var = sum(d**2 for d in downside) / (len(downside) - 1) downside_std = math.sqrt(downside_var) sortino = (mean_r / downside_std * math.sqrt(365)) if downside_std > 0 else 0.0 else: diff --git a/services/backtester/src/backtester/reporter.py b/services/backtester/src/backtester/reporter.py index e9e9936..cc5d67b 100644 --- a/services/backtester/src/backtester/reporter.py +++ b/services/backtester/src/backtester/reporter.py @@ -1,4 +1,5 @@ """Report formatting for backtest results.""" + from __future__ import annotations import csv diff --git a/services/backtester/src/backtester/simulator.py b/services/backtester/src/backtester/simulator.py index b897c5a..33eeb76 100644 --- a/services/backtester/src/backtester/simulator.py +++ b/services/backtester/src/backtester/simulator.py @@ -1,4 +1,5 @@ """Simulated order executor for backtesting.""" + from dataclasses import dataclass, field from datetime import datetime, timezone from decimal import Decimal diff --git a/services/backtester/tests/test_engine.py b/services/backtester/tests/test_engine.py index 1a25e1c..6962477 100644 --- a/services/backtester/tests/test_engine.py +++ b/services/backtester/tests/test_engine.py @@ -1,13 +1,13 @@ """Tests for the BacktestEngine.""" + from datetime import datetime, timezone from decimal import Decimal from unittest.mock import MagicMock -import pytest from shared.models import Candle, Signal, OrderSide -from backtester.engine import BacktestEngine, BacktestResult +from backtester.engine import BacktestEngine def make_candle(symbol: str, price: float, timeframe: str = "1h") -> Candle: diff --git a/services/backtester/tests/test_metrics.py b/services/backtester/tests/test_metrics.py index b222b8a..68bc0b5 100644 --- a/services/backtester/tests/test_metrics.py +++ b/services/backtester/tests/test_metrics.py @@ -1,4 +1,6 @@ """Tests for detailed backtest metrics.""" + +import math from datetime import datetime, timedelta, timezone from decimal import Decimal @@ -21,9 +23,9 @@ def test_compute_metrics_basic(): """Two round-trip trades: 1 win, 1 loss. Verify counts and win_rate.""" trades = [ _make_trade("BUY", "100", 0), - _make_trade("SELL", "120", 10), # win: +20 + _make_trade("SELL", "120", 10), # win: +20 _make_trade("BUY", "130", 20), - _make_trade("SELL", "110", 30), # loss: -20 + _make_trade("SELL", "110", 30), # loss: -20 ] metrics = compute_detailed_metrics(trades, Decimal("10000"), Decimal("10000")) @@ -37,9 +39,9 @@ def test_compute_metrics_profit_factor(): """Verify profit_factor = gross_profit / gross_loss.""" trades = [ _make_trade("BUY", "100", 0), - _make_trade("SELL", "150", 10), # win: +50 + _make_trade("SELL", "150", 10), # win: +50 _make_trade("BUY", "150", 20), - _make_trade("SELL", "130", 30), # loss: -20 + _make_trade("SELL", "130", 30), # loss: -20 ] metrics = compute_detailed_metrics(trades, Decimal("10000"), Decimal("10030")) @@ -51,9 +53,9 @@ def test_compute_metrics_max_drawdown(): """Max drawdown should be > 0 when there is a losing trade after a peak.""" trades = [ _make_trade("BUY", "100", 0), - _make_trade("SELL", "150", 10), # win: equity goes up + _make_trade("SELL", "150", 10), # win: equity goes up _make_trade("BUY", "150", 20), - _make_trade("SELL", "120", 30), # loss: equity drops + _make_trade("SELL", "120", 30), # loss: equity drops ] metrics = compute_detailed_metrics(trades, Decimal("10000"), Decimal("10020")) @@ -91,6 +93,3 @@ def test_compute_metrics_empty_trades(): assert metrics.calmar_ratio == 0.0 assert metrics.max_drawdown == 0.0 assert metrics.monthly_returns == {} - - -import math diff --git a/services/backtester/tests/test_reporter.py b/services/backtester/tests/test_reporter.py index aef3fc6..2ea49c0 100644 --- a/services/backtester/tests/test_reporter.py +++ b/services/backtester/tests/test_reporter.py @@ -1,4 +1,5 @@ """Tests for the report formatter.""" + import json from datetime import timedelta from decimal import Decimal diff --git a/services/backtester/tests/test_simulator.py b/services/backtester/tests/test_simulator.py index 9d8b23e..e8c80ec 100644 --- a/services/backtester/tests/test_simulator.py +++ b/services/backtester/tests/test_simulator.py @@ -1,9 +1,9 @@ """Tests for the OrderSimulator.""" + from decimal import Decimal -import pytest -from shared.models import Signal, OrderSide, OrderType +from shared.models import Signal, OrderSide from backtester.simulator import OrderSimulator diff --git a/services/data-collector/src/data_collector/binance_rest.py b/services/data-collector/src/data_collector/binance_rest.py index af0eb77..eaf4e30 100644 --- a/services/data-collector/src/data_collector/binance_rest.py +++ b/services/data-collector/src/data_collector/binance_rest.py @@ -1,4 +1,5 @@ """Binance REST API helpers for fetching historical candle data.""" + from datetime import datetime, timezone from decimal import Decimal @@ -35,7 +36,7 @@ async def fetch_historical_candles( candles: list[Candle] = [] for row in rows: - ts_ms, o, h, l, c, v = row + ts_ms, o, h, low, c, v = row open_time = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc) candles.append( Candle( @@ -44,7 +45,7 @@ async def fetch_historical_candles( open_time=open_time, open=Decimal(str(o)), high=Decimal(str(h)), - low=Decimal(str(l)), + low=Decimal(str(low)), close=Decimal(str(c)), volume=Decimal(str(v)), ) diff --git a/services/data-collector/src/data_collector/binance_ws.py b/services/data-collector/src/data_collector/binance_ws.py index 7a4bad2..a1c81d6 100644 --- a/services/data-collector/src/data_collector/binance_ws.py +++ b/services/data-collector/src/data_collector/binance_ws.py @@ -1,4 +1,5 @@ """Binance WebSocket client for real-time kline/candle data.""" + import asyncio import json import logging @@ -96,9 +97,7 @@ class BinanceWebSocket: except Exception: if not self._running: break - logger.warning( - "WebSocket disconnected. Reconnecting in %ds…", RECONNECT_DELAY - ) + logger.warning("WebSocket disconnected. Reconnecting in %ds…", RECONNECT_DELAY) await asyncio.sleep(RECONNECT_DELAY) def stop(self) -> None: diff --git a/services/data-collector/src/data_collector/main.py b/services/data-collector/src/data_collector/main.py index c778503..170e8b1 100644 --- a/services/data-collector/src/data_collector/main.py +++ b/services/data-collector/src/data_collector/main.py @@ -1,4 +1,5 @@ """Data Collector Service entry point.""" + import asyncio from shared.broker import RedisBroker @@ -18,7 +19,9 @@ async def run() -> None: config = CollectorConfig() log = setup_logging("data-collector", config.log_level, config.log_format) metrics = ServiceMetrics("data_collector") - notifier = TelegramNotifier(bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id) + notifier = TelegramNotifier( + bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id + ) db = Database(config.database_url) await db.connect() @@ -28,7 +31,12 @@ async def run() -> None: storage = CandleStorage(db=db, broker=broker) async def on_candle(candle): - log.info("candle_received", symbol=candle.symbol, timeframe=candle.timeframe, open_time=str(candle.open_time)) + log.info( + "candle_received", + symbol=candle.symbol, + timeframe=candle.timeframe, + open_time=str(candle.open_time), + ) await storage.store(candle) metrics.events_processed.labels(service="data-collector", event_type="candle").inc() diff --git a/services/data-collector/src/data_collector/storage.py b/services/data-collector/src/data_collector/storage.py index 1e40b82..aeeaaed 100644 --- a/services/data-collector/src/data_collector/storage.py +++ b/services/data-collector/src/data_collector/storage.py @@ -1,4 +1,5 @@ """Candle storage: persists to DB and publishes to Redis.""" + from shared.events import CandleEvent from shared.models import Candle diff --git a/services/data-collector/tests/test_binance_rest.py b/services/data-collector/tests/test_binance_rest.py index 695dcf9..bf88210 100644 --- a/services/data-collector/tests/test_binance_rest.py +++ b/services/data-collector/tests/test_binance_rest.py @@ -1,4 +1,5 @@ """Tests for binance_rest module.""" + import pytest from decimal import Decimal from unittest.mock import AsyncMock, MagicMock @@ -19,9 +20,7 @@ async def test_fetch_historical_candles_parses_response(): ] ) - candles = await fetch_historical_candles( - mock_exchange, "BTC/USDT", "1m", since=ts, limit=500 - ) + candles = await fetch_historical_candles(mock_exchange, "BTC/USDT", "1m", since=ts, limit=500) assert len(candles) == 2 @@ -35,9 +34,7 @@ async def test_fetch_historical_candles_parses_response(): assert c.close == Decimal("30050.0") assert c.volume == Decimal("1.5") - mock_exchange.fetch_ohlcv.assert_called_once_with( - "BTC/USDT", "1m", since=ts, limit=500 - ) + mock_exchange.fetch_ohlcv.assert_called_once_with("BTC/USDT", "1m", since=ts, limit=500) @pytest.mark.asyncio @@ -46,8 +43,6 @@ async def test_fetch_historical_candles_empty_response(): mock_exchange = MagicMock() mock_exchange.fetch_ohlcv = AsyncMock(return_value=[]) - candles = await fetch_historical_candles( - mock_exchange, "BTC/USDT", "1m", since=1700000000000 - ) + candles = await fetch_historical_candles(mock_exchange, "BTC/USDT", "1m", since=1700000000000) assert candles == [] diff --git a/services/data-collector/tests/test_storage.py b/services/data-collector/tests/test_storage.py index 6b27414..be85578 100644 --- a/services/data-collector/tests/test_storage.py +++ b/services/data-collector/tests/test_storage.py @@ -1,4 +1,5 @@ """Tests for storage module.""" + import pytest from decimal import Decimal from datetime import datetime, timezone diff --git a/services/order-executor/src/order_executor/config.py b/services/order-executor/src/order_executor/config.py index 856045f..6542a31 100644 --- a/services/order-executor/src/order_executor/config.py +++ b/services/order-executor/src/order_executor/config.py @@ -1,4 +1,5 @@ """Order Executor configuration.""" + from shared.config import Settings diff --git a/services/order-executor/src/order_executor/executor.py b/services/order-executor/src/order_executor/executor.py index 099520d..80f441d 100644 --- a/services/order-executor/src/order_executor/executor.py +++ b/services/order-executor/src/order_executor/executor.py @@ -1,4 +1,5 @@ """Order execution logic.""" + import structlog from datetime import datetime, timezone from decimal import Decimal @@ -7,7 +8,7 @@ from typing import Any, Optional from shared.broker import RedisBroker from shared.db import Database from shared.events import OrderEvent -from shared.models import Order, OrderSide, OrderStatus, OrderType, Signal +from shared.models import Order, OrderStatus, OrderType, Signal from shared.notifier import TelegramNotifier from order_executor.risk_manager import RiskManager @@ -58,9 +59,7 @@ class OrderExecutor: ) if not result.allowed: - logger.warning( - "risk_check_rejected", signal_id=str(signal.id), reason=result.reason - ) + logger.warning("risk_check_rejected", signal_id=str(signal.id), reason=result.reason) return None # Build the order model @@ -77,7 +76,12 @@ class OrderExecutor: if self.dry_run: order.status = OrderStatus.FILLED order.filled_at = datetime.now(timezone.utc) - logger.info("order_filled_dry_run", side=str(order.side), quantity=str(order.quantity), symbol=order.symbol) + logger.info( + "order_filled_dry_run", + side=str(order.side), + quantity=str(order.quantity), + symbol=order.symbol, + ) else: try: await self.exchange.create_order( @@ -88,7 +92,12 @@ class OrderExecutor: ) order.status = OrderStatus.FILLED order.filled_at = datetime.now(timezone.utc) - logger.info("order_filled", side=str(order.side), quantity=str(order.quantity), symbol=order.symbol) + logger.info( + "order_filled", + side=str(order.side), + quantity=str(order.quantity), + symbol=order.symbol, + ) except Exception as exc: order.status = OrderStatus.FAILED logger.error("order_failed", signal_id=str(signal.id), error=str(exc)) diff --git a/services/order-executor/src/order_executor/main.py b/services/order-executor/src/order_executor/main.py index 7f0578d..ab6ef4f 100644 --- a/services/order-executor/src/order_executor/main.py +++ b/services/order-executor/src/order_executor/main.py @@ -1,4 +1,5 @@ """Order Executor Service entry point.""" + import asyncio from decimal import Decimal @@ -21,7 +22,9 @@ async def run() -> None: config = ExecutorConfig() log = setup_logging("order-executor", config.log_level, config.log_format) metrics = ServiceMetrics("order_executor") - notifier = TelegramNotifier(bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id) + notifier = TelegramNotifier( + bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id + ) db = Database(config.database_url) await db.connect() @@ -69,12 +72,18 @@ async def run() -> None: event = Event.from_dict(msg) if event.type == EventType.SIGNAL: signal = event.data - log.info("processing_signal", signal_id=str(signal.id), symbol=signal.symbol) + log.info( + "processing_signal", signal_id=str(signal.id), symbol=signal.symbol + ) await executor.execute(signal) - metrics.events_processed.labels(service="order-executor", event_type="signal").inc() + metrics.events_processed.labels( + service="order-executor", event_type="signal" + ).inc() except Exception as exc: log.error("message_processing_failed", error=str(exc)) - metrics.errors_total.labels(service="order-executor", error_type="processing").inc() + metrics.errors_total.labels( + service="order-executor", error_type="processing" + ).inc() if messages: # Advance last_id to avoid re-reading — broker.read returns decoded dicts, # so we track progress by re-reading with "0" for replaying or "$" for new only. diff --git a/services/order-executor/src/order_executor/risk_manager.py b/services/order-executor/src/order_executor/risk_manager.py index 8e8a72c..db162e1 100644 --- a/services/order-executor/src/order_executor/risk_manager.py +++ b/services/order-executor/src/order_executor/risk_manager.py @@ -1,4 +1,5 @@ """Risk management for order execution.""" + from dataclasses import dataclass from decimal import Decimal @@ -49,7 +50,10 @@ class RiskManager: if position is not None: current_position_value = position.quantity * position.current_price - if balance > 0 and (current_position_value + order_cost) / balance > self.max_position_size: + if ( + balance > 0 + and (current_position_value + order_cost) / balance > self.max_position_size + ): return RiskCheckResult(allowed=False, reason="Position size exceeded") return RiskCheckResult(allowed=True, reason="OK") diff --git a/services/order-executor/tests/test_executor.py b/services/order-executor/tests/test_executor.py index 4836ffb..e64b6c0 100644 --- a/services/order-executor/tests/test_executor.py +++ b/services/order-executor/tests/test_executor.py @@ -1,4 +1,5 @@ """Tests for OrderExecutor.""" + from decimal import Decimal from unittest.mock import AsyncMock, MagicMock diff --git a/services/order-executor/tests/test_risk_manager.py b/services/order-executor/tests/test_risk_manager.py index f6b5545..a122d16 100644 --- a/services/order-executor/tests/test_risk_manager.py +++ b/services/order-executor/tests/test_risk_manager.py @@ -1,9 +1,9 @@ """Tests for RiskManager.""" + from decimal import Decimal -import pytest -from shared.models import OrderSide, Position, Signal +from shared.models import OrderSide, Signal from order_executor.risk_manager import RiskManager @@ -55,9 +55,7 @@ def test_risk_check_rejects_daily_loss_exceeded(): """Daily PnL of -1100 on 10000 balance = -11%, exceeding -10% limit.""" rm = make_risk_manager(daily_loss_limit_pct="10.0") signal = make_signal(side=OrderSide.BUY, price="100", quantity="0.1") - result = rm.check( - signal, balance=Decimal("10000"), positions={}, daily_pnl=Decimal("-1100") - ) + result = rm.check(signal, balance=Decimal("10000"), positions={}, daily_pnl=Decimal("-1100")) assert result.allowed is False assert result.reason == "Daily loss limit exceeded" diff --git a/services/portfolio-manager/src/portfolio_manager/config.py b/services/portfolio-manager/src/portfolio_manager/config.py index bbd5049..eaf53fd 100644 --- a/services/portfolio-manager/src/portfolio_manager/config.py +++ b/services/portfolio-manager/src/portfolio_manager/config.py @@ -1,4 +1,5 @@ """Portfolio Manager configuration.""" + from shared.config import Settings diff --git a/services/portfolio-manager/src/portfolio_manager/main.py b/services/portfolio-manager/src/portfolio_manager/main.py index 56624f7..a1c73be 100644 --- a/services/portfolio-manager/src/portfolio_manager/main.py +++ b/services/portfolio-manager/src/portfolio_manager/main.py @@ -1,4 +1,5 @@ """Portfolio Manager Service entry point.""" + import asyncio from shared.broker import RedisBroker @@ -18,7 +19,9 @@ async def run() -> None: config = PortfolioConfig() log = setup_logging("portfolio-manager", config.log_level, config.log_format) metrics = ServiceMetrics("portfolio_manager") - notifier = TelegramNotifier(bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id) + notifier = TelegramNotifier( + bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id + ) broker = RedisBroker(config.redis_url) tracker = PortfolioTracker() @@ -49,10 +52,14 @@ async def run() -> None: ) positions = tracker.get_all_positions() log.info("positions_updated", count=len(positions)) - metrics.events_processed.labels(service="portfolio-manager", event_type="order").inc() + metrics.events_processed.labels( + service="portfolio-manager", event_type="order" + ).inc() except Exception as exc: log.exception("message_processing_failed", error=str(exc)) - metrics.errors_total.labels(service="portfolio-manager", error_type="processing").inc() + metrics.errors_total.labels( + service="portfolio-manager", error_type="processing" + ).inc() # Update last_id to the latest processed message id if broker returns ids # Since broker.read returns parsed payloads (not ids), we use "$" to get new msgs except Exception as exc: diff --git a/services/portfolio-manager/src/portfolio_manager/pnl.py b/services/portfolio-manager/src/portfolio_manager/pnl.py index 96f0da8..4c0fa56 100644 --- a/services/portfolio-manager/src/portfolio_manager/pnl.py +++ b/services/portfolio-manager/src/portfolio_manager/pnl.py @@ -1,4 +1,5 @@ """PnL calculation functions for the portfolio manager.""" + from decimal import Decimal diff --git a/services/portfolio-manager/src/portfolio_manager/portfolio.py b/services/portfolio-manager/src/portfolio_manager/portfolio.py index 59106bb..2c93643 100644 --- a/services/portfolio-manager/src/portfolio_manager/portfolio.py +++ b/services/portfolio-manager/src/portfolio_manager/portfolio.py @@ -1,4 +1,5 @@ """Portfolio tracking for the portfolio manager service.""" + from decimal import Decimal from shared.models import Order, OrderSide, Position diff --git a/services/portfolio-manager/tests/test_pnl.py b/services/portfolio-manager/tests/test_pnl.py index 4462adc..5f5d807 100644 --- a/services/portfolio-manager/tests/test_pnl.py +++ b/services/portfolio-manager/tests/test_pnl.py @@ -1,4 +1,5 @@ """Tests for PnL calculation functions.""" + from decimal import Decimal from portfolio_manager.pnl import calculate_realized_pnl, calculate_unrealized_pnl diff --git a/services/portfolio-manager/tests/test_portfolio.py b/services/portfolio-manager/tests/test_portfolio.py index 26319ca..92ff6ca 100644 --- a/services/portfolio-manager/tests/test_portfolio.py +++ b/services/portfolio-manager/tests/test_portfolio.py @@ -1,4 +1,5 @@ """Tests for PortfolioTracker.""" + from decimal import Decimal from shared.models import Order, OrderSide, OrderStatus, OrderType diff --git a/services/strategy-engine/src/strategy_engine/config.py b/services/strategy-engine/src/strategy_engine/config.py index 2864b09..e3a49c2 100644 --- a/services/strategy-engine/src/strategy_engine/config.py +++ b/services/strategy-engine/src/strategy_engine/config.py @@ -1,4 +1,5 @@ """Strategy Engine configuration.""" + from shared.config import Settings diff --git a/services/strategy-engine/src/strategy_engine/engine.py b/services/strategy-engine/src/strategy_engine/engine.py index 09dbf65..d401aee 100644 --- a/services/strategy-engine/src/strategy_engine/engine.py +++ b/services/strategy-engine/src/strategy_engine/engine.py @@ -1,4 +1,5 @@ """Strategy Engine: consumes candle events and publishes signals.""" + import logging from shared.broker import RedisBroker @@ -36,9 +37,7 @@ class StrategyEngine: try: signal = strategy.on_candle(candle) except Exception as exc: - logger.error( - "Strategy %s raised on candle: %s", strategy.name, exc - ) + logger.error("Strategy %s raised on candle: %s", strategy.name, exc) continue if signal is not None: diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py index 2e3c4ac..53681d1 100644 --- a/services/strategy-engine/src/strategy_engine/main.py +++ b/services/strategy-engine/src/strategy_engine/main.py @@ -1,4 +1,5 @@ """Strategy Engine Service entry point.""" + import asyncio from pathlib import Path @@ -20,7 +21,9 @@ async def run() -> None: config = StrategyConfig() log = setup_logging("strategy-engine", config.log_level, config.log_format) metrics = ServiceMetrics("strategy_engine") - notifier = TelegramNotifier(bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id) + notifier = TelegramNotifier( + bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id + ) broker = RedisBroker(config.redis_url) @@ -53,7 +56,9 @@ async def run() -> None: while True: last_id = await engine.process_once(stream, last_id) - metrics.events_processed.labels(service="strategy-engine", event_type="candle").inc() + metrics.events_processed.labels( + service="strategy-engine", event_type="candle" + ).inc() except Exception as exc: log.error("fatal_error", error=str(exc)) await notifier.send_error(str(exc), "strategy-engine") diff --git a/services/strategy-engine/src/strategy_engine/plugin_loader.py b/services/strategy-engine/src/strategy_engine/plugin_loader.py index f99b670..62e4160 100644 --- a/services/strategy-engine/src/strategy_engine/plugin_loader.py +++ b/services/strategy-engine/src/strategy_engine/plugin_loader.py @@ -1,4 +1,5 @@ """Dynamic plugin loader for strategy modules.""" + import importlib.util import sys from pathlib import Path @@ -29,11 +30,7 @@ def load_strategies(strategies_dir: Path) -> list[BaseStrategy]: for attr_name in dir(module): obj = getattr(module, attr_name) - if ( - isinstance(obj, type) - and issubclass(obj, BaseStrategy) - and obj is not BaseStrategy - ): + if isinstance(obj, type) and issubclass(obj, BaseStrategy) and obj is not BaseStrategy: instance = obj() yaml_path = config_dir / f"{path.stem}.yaml" if yaml_path.exists(): diff --git a/services/strategy-engine/strategies/volume_profile_strategy.py b/services/strategy-engine/strategies/volume_profile_strategy.py index 684c33c..e9463bf 100644 --- a/services/strategy-engine/strategies/volume_profile_strategy.py +++ b/services/strategy-engine/strategies/volume_profile_strategy.py @@ -39,9 +39,8 @@ class VolumeProfileStrategy(BaseStrategy): if len(data) < self._lookback_period: return None - recent = data[-self._lookback_period:] + recent = data[-self._lookback_period :] prices = np.array([c[0] for c in recent]) - volumes = np.array([c[1] for c in recent]) min_price = prices.min() max_price = prices.max() diff --git a/services/strategy-engine/tests/conftest.py b/services/strategy-engine/tests/conftest.py index c9ef308..eb31b23 100644 --- a/services/strategy-engine/tests/conftest.py +++ b/services/strategy-engine/tests/conftest.py @@ -1,4 +1,5 @@ """Pytest configuration: ensure strategies/ is importable.""" + import sys from pathlib import Path diff --git a/services/strategy-engine/tests/test_bollinger_strategy.py b/services/strategy-engine/tests/test_bollinger_strategy.py index b3d17ac..348a9e0 100644 --- a/services/strategy-engine/tests/test_bollinger_strategy.py +++ b/services/strategy-engine/tests/test_bollinger_strategy.py @@ -1,8 +1,8 @@ """Tests for the Bollinger Bands strategy.""" + from datetime import datetime, timezone from decimal import Decimal -import pytest from shared.models import Candle, OrderSide from strategies.bollinger_strategy import BollingerStrategy diff --git a/services/strategy-engine/tests/test_ema_crossover_strategy.py b/services/strategy-engine/tests/test_ema_crossover_strategy.py index 5a40319..0cf767b 100644 --- a/services/strategy-engine/tests/test_ema_crossover_strategy.py +++ b/services/strategy-engine/tests/test_ema_crossover_strategy.py @@ -1,8 +1,8 @@ """Tests for the EMA Crossover strategy.""" + from datetime import datetime, timezone from decimal import Decimal -import pytest from shared.models import Candle, OrderSide from strategies.ema_crossover_strategy import EmaCrossoverStrategy diff --git a/services/strategy-engine/tests/test_engine.py b/services/strategy-engine/tests/test_engine.py index 33ad4dd..ac9a596 100644 --- a/services/strategy-engine/tests/test_engine.py +++ b/services/strategy-engine/tests/test_engine.py @@ -1,4 +1,5 @@ """Tests for the StrategyEngine.""" + from datetime import datetime, timezone from decimal import Decimal from unittest.mock import AsyncMock, MagicMock @@ -6,7 +7,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest from shared.models import Candle, Signal, OrderSide -from shared.events import CandleEvent, SignalEvent +from shared.events import CandleEvent from strategy_engine.engine import StrategyEngine diff --git a/services/strategy-engine/tests/test_grid_strategy.py b/services/strategy-engine/tests/test_grid_strategy.py index d96ebba..79eb22a 100644 --- a/services/strategy-engine/tests/test_grid_strategy.py +++ b/services/strategy-engine/tests/test_grid_strategy.py @@ -1,8 +1,8 @@ """Tests for the Grid strategy.""" + from datetime import datetime, timezone from decimal import Decimal -import pytest from shared.models import Candle, OrderSide from strategies.grid_strategy import GridStrategy @@ -23,12 +23,14 @@ def make_candle(close: float) -> Candle: def _configured_strategy() -> GridStrategy: strategy = GridStrategy() - strategy.configure({ - "lower_price": 48000, - "upper_price": 52000, - "grid_count": 5, - "quantity": "0.01", - }) + strategy.configure( + { + "lower_price": 48000, + "upper_price": 52000, + "grid_count": 5, + "quantity": "0.01", + } + ) return strategy diff --git a/services/strategy-engine/tests/test_macd_strategy.py b/services/strategy-engine/tests/test_macd_strategy.py index e1ae2a3..9931b43 100644 --- a/services/strategy-engine/tests/test_macd_strategy.py +++ b/services/strategy-engine/tests/test_macd_strategy.py @@ -1,8 +1,8 @@ """Tests for the MACD strategy.""" + from datetime import datetime, timezone from decimal import Decimal -import pytest from shared.models import Candle, OrderSide from strategies.macd_strategy import MacdStrategy diff --git a/services/strategy-engine/tests/test_plugin_loader.py b/services/strategy-engine/tests/test_plugin_loader.py index 9496bab..5191fc3 100644 --- a/services/strategy-engine/tests/test_plugin_loader.py +++ b/services/strategy-engine/tests/test_plugin_loader.py @@ -1,7 +1,7 @@ """Tests for the plugin loader.""" + from pathlib import Path -import pytest from strategy_engine.plugin_loader import load_strategies diff --git a/services/strategy-engine/tests/test_rsi_strategy.py b/services/strategy-engine/tests/test_rsi_strategy.py index 90fface..2a2f4e7 100644 --- a/services/strategy-engine/tests/test_rsi_strategy.py +++ b/services/strategy-engine/tests/test_rsi_strategy.py @@ -1,8 +1,8 @@ """Tests for the RSI strategy.""" + from datetime import datetime, timezone from decimal import Decimal -import pytest from shared.models import Candle, OrderSide from strategies.rsi_strategy import RsiStrategy diff --git a/services/strategy-engine/tests/test_volume_profile_strategy.py b/services/strategy-engine/tests/test_volume_profile_strategy.py index be123b0..71f0eca 100644 --- a/services/strategy-engine/tests/test_volume_profile_strategy.py +++ b/services/strategy-engine/tests/test_volume_profile_strategy.py @@ -1,8 +1,8 @@ """Tests for the Volume Profile strategy.""" + from datetime import datetime, timezone from decimal import Decimal -import pytest from shared.models import Candle, OrderSide from strategies.volume_profile_strategy import VolumeProfileStrategy @@ -39,17 +39,27 @@ def test_volume_profile_no_signal_insufficient_data(): def test_volume_profile_buy_at_value_area_low(): """Concentrate volume around 95-105, price drops to 88, bounces back to 99.""" strategy = VolumeProfileStrategy() - strategy.configure({ - "lookback_period": 10, - "num_bins": 5, - "value_area_pct": 0.7, - "quantity": "0.01", - }) + strategy.configure( + { + "lookback_period": 10, + "num_bins": 5, + "value_area_pct": 0.7, + "quantity": "0.01", + } + ) # Build profile: 10 candles with volume concentrated around 95-105 profile_data = [ - (95, 50), (97, 50), (99, 100), (100, 100), (101, 100), - (103, 50), (105, 50), (100, 100), (99, 100), (101, 50), + (95, 50), + (97, 50), + (99, 100), + (100, 100), + (101, 100), + (103, 50), + (105, 50), + (100, 100), + (99, 100), + (101, 50), ] for price, vol in profile_data: strategy.on_candle(make_candle(price, vol)) @@ -67,17 +77,27 @@ def test_volume_profile_buy_at_value_area_low(): def test_volume_profile_sell_at_value_area_high(): """Concentrate volume around 95-105, price rises to 112, pulls back to 101.""" strategy = VolumeProfileStrategy() - strategy.configure({ - "lookback_period": 10, - "num_bins": 5, - "value_area_pct": 0.7, - "quantity": "0.01", - }) + strategy.configure( + { + "lookback_period": 10, + "num_bins": 5, + "value_area_pct": 0.7, + "quantity": "0.01", + } + ) # Build profile: 10 candles with volume concentrated around 95-105 profile_data = [ - (95, 50), (97, 50), (99, 100), (100, 100), (101, 100), - (103, 50), (105, 50), (100, 100), (99, 100), (101, 50), + (95, 50), + (97, 50), + (99, 100), + (100, 100), + (101, 100), + (103, 50), + (105, 50), + (100, 100), + (99, 100), + (101, 50), ] for price, vol in profile_data: strategy.on_candle(make_candle(price, vol)) diff --git a/services/strategy-engine/tests/test_vwap_strategy.py b/services/strategy-engine/tests/test_vwap_strategy.py index 37d35bc..5d76b04 100644 --- a/services/strategy-engine/tests/test_vwap_strategy.py +++ b/services/strategy-engine/tests/test_vwap_strategy.py @@ -1,8 +1,8 @@ """Tests for the VWAP strategy.""" + from datetime import datetime, timezone from decimal import Decimal -import pytest from shared.models import Candle, OrderSide from strategies.vwap_strategy import VwapStrategy diff --git a/shared/src/shared/broker.py b/shared/src/shared/broker.py index 0f87b06..9c6c4c6 100644 --- a/shared/src/shared/broker.py +++ b/shared/src/shared/broker.py @@ -1,4 +1,5 @@ """Redis Streams broker for the trading platform.""" + import json from typing import Any @@ -24,9 +25,7 @@ class RedisBroker: block: int = 0, ) -> list[dict[str, Any]]: """Read messages from a Redis stream.""" - results = await self._redis.xread( - {stream: last_id}, count=count, block=block - ) + results = await self._redis.xread({stream: last_id}, count=count, block=block) messages = [] if results: for _stream, entries in results: diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py index 511654f..47bc2b1 100644 --- a/shared/src/shared/config.py +++ b/shared/src/shared/config.py @@ -1,4 +1,5 @@ """Shared configuration settings for the trading platform.""" + from pydantic_settings import BaseSettings diff --git a/shared/src/shared/db.py b/shared/src/shared/db.py index 95e487e..f9b7f56 100644 --- a/shared/src/shared/db.py +++ b/shared/src/shared/db.py @@ -1,4 +1,5 @@ """Database layer using SQLAlchemy 2.0 async ORM for the trading platform.""" + from datetime import datetime from typing import Optional @@ -107,9 +108,7 @@ class Database: await session.execute(stmt) await session.commit() - async def get_candles( - self, symbol: str, timeframe: str, limit: int = 500 - ) -> list[dict]: + async def get_candles(self, symbol: str, timeframe: str, limit: int = 500) -> list[dict]: """Retrieve candles ordered by open_time descending.""" stmt = ( select(CandleRow) diff --git a/shared/src/shared/events.py b/shared/src/shared/events.py index 1db2bee..72f8865 100644 --- a/shared/src/shared/events.py +++ b/shared/src/shared/events.py @@ -1,4 +1,5 @@ """Event types and serialization for the trading platform.""" + from enum import Enum from typing import Any diff --git a/shared/src/shared/healthcheck.py b/shared/src/shared/healthcheck.py index 8294294..be02712 100644 --- a/shared/src/shared/healthcheck.py +++ b/shared/src/shared/healthcheck.py @@ -1,4 +1,5 @@ """Health check HTTP server with Prometheus metrics endpoint.""" + from __future__ import annotations import time diff --git a/shared/src/shared/logging.py b/shared/src/shared/logging.py index b873eaf..9e42cdc 100644 --- a/shared/src/shared/logging.py +++ b/shared/src/shared/logging.py @@ -1,4 +1,5 @@ """Structured logging configuration using structlog.""" + from __future__ import annotations import logging diff --git a/shared/src/shared/metrics.py b/shared/src/shared/metrics.py index 3b00c5d..cd239f3 100644 --- a/shared/src/shared/metrics.py +++ b/shared/src/shared/metrics.py @@ -1,4 +1,5 @@ """Prometheus metrics for trading platform services.""" + from __future__ import annotations from prometheus_client import Counter, Gauge, Histogram, CollectorRegistry, REGISTRY diff --git a/shared/src/shared/models.py b/shared/src/shared/models.py index 4cb1081..0e8ca44 100644 --- a/shared/src/shared/models.py +++ b/shared/src/shared/models.py @@ -1,4 +1,5 @@ """Shared Pydantic models for the trading platform.""" + import uuid from decimal import Decimal from datetime import datetime, timezone diff --git a/shared/src/shared/notifier.py b/shared/src/shared/notifier.py index de86f87..f03919c 100644 --- a/shared/src/shared/notifier.py +++ b/shared/src/shared/notifier.py @@ -1,4 +1,5 @@ """Telegram notification service for the trading platform.""" + import asyncio import logging from decimal import Decimal @@ -63,9 +64,7 @@ class TelegramNotifier: body, ) except Exception: - logger.exception( - "Telegram send failed (attempt %d/%d)", attempt, MAX_RETRIES - ) + logger.exception("Telegram send failed (attempt %d/%d)", attempt, MAX_RETRIES) if attempt < MAX_RETRIES: await asyncio.sleep(attempt) @@ -96,11 +95,7 @@ class TelegramNotifier: async def send_error(self, error: str, service: str) -> None: """Format and send an error alert.""" - msg = ( - "<b>🚨 Error Alert</b>\n" - f"Service: <b>{service}</b>\n" - f"Error: {error}" - ) + msg = f"<b>🚨 Error Alert</b>\nService: <b>{service}</b>\nError: {error}" await self.send(msg) async def send_daily_summary( diff --git a/shared/src/shared/resilience.py b/shared/src/shared/resilience.py index d4e963b..e43fd21 100644 --- a/shared/src/shared/resilience.py +++ b/shared/src/shared/resilience.py @@ -35,7 +35,7 @@ def retry_with_backoff( except Exception as exc: last_exc = exc if attempt < max_retries: - delay = min(base_delay * (2 ** attempt), max_delay) + delay = min(base_delay * (2**attempt), max_delay) jitter = delay * random.uniform(0, 0.5) total_delay = delay + jitter logger.warning( diff --git a/shared/src/shared/sa_models.py b/shared/src/shared/sa_models.py index 0537846..8386ba8 100644 --- a/shared/src/shared/sa_models.py +++ b/shared/src/shared/sa_models.py @@ -41,17 +41,13 @@ class OrderRow(Base): __tablename__ = "orders" id: Mapped[str] = mapped_column(Text, primary_key=True) - signal_id: Mapped[str | None] = mapped_column( - Text, ForeignKey("signals.id") - ) + signal_id: Mapped[str | None] = mapped_column(Text, ForeignKey("signals.id")) symbol: Mapped[str] = mapped_column(Text, nullable=False) side: Mapped[str] = mapped_column(Text, nullable=False) type: Mapped[str] = mapped_column(Text, nullable=False) price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) - status: Mapped[str] = mapped_column( - Text, nullable=False, server_default="PENDING" - ) + status: Mapped[str] = mapped_column(Text, nullable=False, server_default="PENDING") created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) filled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) @@ -60,16 +56,12 @@ class TradeRow(Base): __tablename__ = "trades" id: Mapped[str] = mapped_column(Text, primary_key=True) - order_id: Mapped[str | None] = mapped_column( - Text, ForeignKey("orders.id") - ) + order_id: Mapped[str | None] = mapped_column(Text, ForeignKey("orders.id")) symbol: Mapped[str] = mapped_column(Text, nullable=False) side: Mapped[str] = mapped_column(Text, nullable=False) price: Mapped[Decimal] = mapped_column(Numeric, nullable=False) quantity: Mapped[Decimal] = mapped_column(Numeric, nullable=False) - fee: Mapped[Decimal] = mapped_column( - Numeric, nullable=False, server_default="0" - ) + fee: Mapped[Decimal] = mapped_column(Numeric, nullable=False, server_default="0") traded_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) diff --git a/shared/tests/test_broker.py b/shared/tests/test_broker.py index d3a3569..ea8b148 100644 --- a/shared/tests/test_broker.py +++ b/shared/tests/test_broker.py @@ -1,7 +1,8 @@ """Tests for the Redis broker.""" + import pytest import json -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch @pytest.mark.asyncio @@ -12,6 +13,7 @@ async def test_broker_publish(): mock_from_url.return_value = mock_redis from shared.broker import RedisBroker + broker = RedisBroker("redis://localhost:6379") data = {"type": "CANDLE", "symbol": "BTCUSDT"} await broker.publish("candles", data) @@ -43,6 +45,7 @@ async def test_broker_subscribe_returns_messages(): ] from shared.broker import RedisBroker + broker = RedisBroker("redis://localhost:6379") messages = await broker.read("candles", last_id="$") @@ -60,6 +63,7 @@ async def test_broker_close(): mock_from_url.return_value = mock_redis from shared.broker import RedisBroker + broker = RedisBroker("redis://localhost:6379") await broker.close() diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index 45d5dcd..b9a9d56 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -1,12 +1,14 @@ """Tests for the SQLAlchemy async database layer.""" + import pytest from decimal import Decimal from datetime import datetime, timezone -from unittest.mock import AsyncMock, MagicMock, patch, call +from unittest.mock import AsyncMock, MagicMock, patch def make_candle(): from shared.models import Candle + return Candle( symbol="BTCUSDT", timeframe="1m", @@ -21,6 +23,7 @@ def make_candle(): def make_signal(): from shared.models import Signal, OrderSide + return Signal( id="sig-1", strategy="ma_cross", @@ -35,6 +38,7 @@ def make_signal(): def make_order(): from shared.models import Order, OrderSide, OrderType, OrderStatus + return Order( id="ord-1", signal_id="sig-1", @@ -51,21 +55,25 @@ def make_order(): class TestDatabaseConstructor: def test_stores_url(self): from shared.db import Database + db = Database("postgresql://user:pass@localhost/db") assert db._database_url == "postgresql+asyncpg://user:pass@localhost/db" def test_converts_url_prefix(self): from shared.db import Database + db = Database("postgresql://host/db") assert db._database_url.startswith("postgresql+asyncpg://") def test_keeps_asyncpg_prefix(self): from shared.db import Database + db = Database("postgresql+asyncpg://host/db") assert db._database_url == "postgresql+asyncpg://host/db" def test_get_session_exists(self): from shared.db import Database + db = Database("postgresql+asyncpg://host/db") assert hasattr(db, "get_session") @@ -74,6 +82,7 @@ class TestDatabaseConnect: @pytest.mark.asyncio async def test_connect_creates_engine_and_tables(self): from shared.db import Database + db = Database("postgresql+asyncpg://host/db") mock_conn = AsyncMock() @@ -94,6 +103,7 @@ class TestDatabaseConnect: @pytest.mark.asyncio async def test_init_tables_is_alias_for_connect(self): from shared.db import Database + db = Database("postgresql+asyncpg://host/db") mock_conn = AsyncMock() @@ -116,6 +126,7 @@ class TestDatabaseClose: @pytest.mark.asyncio async def test_close_disposes_engine(self): from shared.db import Database + db = Database("postgresql+asyncpg://host/db") mock_engine = AsyncMock() db._engine = mock_engine @@ -127,6 +138,7 @@ class TestInsertCandle: @pytest.mark.asyncio async def test_insert_candle_uses_merge_and_commit(self): from shared.db import Database + db = Database("postgresql+asyncpg://host/db") mock_session = AsyncMock() @@ -147,6 +159,7 @@ class TestInsertSignal: @pytest.mark.asyncio async def test_insert_signal_uses_add_and_commit(self): from shared.db import Database + db = Database("postgresql+asyncpg://host/db") mock_session = AsyncMock() @@ -167,6 +180,7 @@ class TestInsertOrder: @pytest.mark.asyncio async def test_insert_order_uses_add_and_commit(self): from shared.db import Database + db = Database("postgresql+asyncpg://host/db") mock_session = AsyncMock() @@ -188,6 +202,7 @@ class TestUpdateOrderStatus: async def test_update_order_status_uses_execute_and_commit(self): from shared.db import Database from shared.models import OrderStatus + db = Database("postgresql+asyncpg://host/db") mock_session = AsyncMock() @@ -207,6 +222,7 @@ class TestGetCandles: @pytest.mark.asyncio async def test_get_candles_returns_list_of_dicts(self): from shared.db import Database + db = Database("postgresql+asyncpg://host/db") # Create a mock row that behaves like a SA result row diff --git a/shared/tests/test_events.py b/shared/tests/test_events.py index 4bc7981..ab7792b 100644 --- a/shared/tests/test_events.py +++ b/shared/tests/test_events.py @@ -1,11 +1,12 @@ """Tests for shared event types.""" -import pytest + from decimal import Decimal from datetime import datetime, timezone def make_candle(): from shared.models import Candle + return Candle( symbol="BTCUSDT", timeframe="1m", @@ -20,6 +21,7 @@ def make_candle(): def make_signal(): from shared.models import Signal, OrderSide + return Signal( strategy="test", symbol="BTCUSDT", @@ -33,6 +35,7 @@ def make_signal(): def test_candle_event_serialize(): """Test CandleEvent serializes to dict correctly.""" from shared.events import CandleEvent, EventType + candle = make_candle() event = CandleEvent(data=candle) d = event.to_dict() @@ -44,6 +47,7 @@ def test_candle_event_serialize(): def test_candle_event_deserialize(): """Test CandleEvent round-trips through to_dict/from_raw.""" from shared.events import CandleEvent, EventType + candle = make_candle() event = CandleEvent(data=candle) d = event.to_dict() @@ -56,6 +60,7 @@ def test_candle_event_deserialize(): def test_signal_event_serialize(): """Test SignalEvent serializes to dict correctly.""" from shared.events import SignalEvent, EventType + signal = make_signal() event = SignalEvent(data=signal) d = event.to_dict() @@ -66,7 +71,8 @@ def test_signal_event_serialize(): def test_event_from_dict_dispatch(): """Test Event.from_dict dispatches to correct class.""" - from shared.events import Event, CandleEvent, SignalEvent, EventType + from shared.events import Event, CandleEvent, SignalEvent + candle = make_candle() event = CandleEvent(data=candle) d = event.to_dict() diff --git a/shared/tests/test_healthcheck.py b/shared/tests/test_healthcheck.py index 1af86b1..6970a8f 100644 --- a/shared/tests/test_healthcheck.py +++ b/shared/tests/test_healthcheck.py @@ -1,6 +1,6 @@ """Tests for health check server.""" + import pytest -import asyncio from prometheus_client import CollectorRegistry @@ -11,6 +11,7 @@ def registry(): def make_server(service_name="test-service", port=8080, registry=None): from shared.healthcheck import HealthCheckServer + return HealthCheckServer(service_name, port=port, registry=registry) diff --git a/shared/tests/test_logging.py b/shared/tests/test_logging.py index 4abd254..2ffddcd 100644 --- a/shared/tests/test_logging.py +++ b/shared/tests/test_logging.py @@ -1,4 +1,5 @@ """Tests for shared structured logging module.""" + import io import json import logging diff --git a/shared/tests/test_metrics.py b/shared/tests/test_metrics.py index 079f01c..3fd72a7 100644 --- a/shared/tests/test_metrics.py +++ b/shared/tests/test_metrics.py @@ -1,10 +1,11 @@ """Tests for Prometheus metrics utilities.""" -import pytest + from prometheus_client import CollectorRegistry def make_metrics(service_name="test-service", registry=None): from shared.metrics import ServiceMetrics + return ServiceMetrics(service_name, registry=registry) diff --git a/shared/tests/test_models.py b/shared/tests/test_models.py index f1d92ec..25ab4c9 100644 --- a/shared/tests/test_models.py +++ b/shared/tests/test_models.py @@ -1,6 +1,6 @@ """Tests for shared models and settings.""" + import os -import pytest from decimal import Decimal from datetime import datetime, timezone from unittest.mock import patch @@ -8,11 +8,15 @@ from unittest.mock import patch def test_settings_defaults(): """Test that Settings has correct defaults.""" - with patch.dict(os.environ, { - "BINANCE_API_KEY": "test_key", - "BINANCE_API_SECRET": "test_secret", - }): + with patch.dict( + os.environ, + { + "BINANCE_API_KEY": "test_key", + "BINANCE_API_SECRET": "test_secret", + }, + ): from shared.config import Settings + settings = Settings() assert settings.redis_url == "redis://localhost:6379" assert settings.database_url == "postgresql://trading:trading@localhost:5432/trading" @@ -26,6 +30,7 @@ def test_settings_defaults(): def test_candle_creation(): """Test Candle model creation.""" from shared.models import Candle + now = datetime.now(timezone.utc) candle = Candle( symbol="BTCUSDT", @@ -49,6 +54,7 @@ def test_candle_creation(): def test_signal_creation(): """Test Signal model creation.""" from shared.models import Signal, OrderSide + signal = Signal( strategy="rsi_strategy", symbol="BTCUSDT", @@ -71,6 +77,7 @@ def test_order_creation(): """Test Order model creation with defaults.""" from shared.models import Order, OrderSide, OrderType, OrderStatus import uuid + signal_id = str(uuid.uuid4()) order = Order( signal_id=signal_id, @@ -90,6 +97,7 @@ def test_order_creation(): def test_position_unrealized_pnl(): """Test Position unrealized_pnl computed property.""" from shared.models import Position + position = Position( symbol="BTCUSDT", quantity=Decimal("0.1"), diff --git a/shared/tests/test_notifier.py b/shared/tests/test_notifier.py index 09e731a..3d29830 100644 --- a/shared/tests/test_notifier.py +++ b/shared/tests/test_notifier.py @@ -1,12 +1,10 @@ """Tests for Telegram notification service.""" -import os + import uuid from decimal import Decimal -from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest -import pytest_asyncio from shared.models import Signal, Order, OrderSide, OrderType, OrderStatus, Position from shared.notifier import TelegramNotifier diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py index 514bcc2..e287777 100644 --- a/shared/tests/test_resilience.py +++ b/shared/tests/test_resilience.py @@ -1,5 +1,5 @@ """Tests for retry with backoff and circuit breaker.""" -import asyncio + import time import pytest @@ -63,6 +63,7 @@ async def test_retry_raises_after_max_retries(): @pytest.mark.asyncio async def test_retry_respects_max_delay(): """Backoff should be capped at max_delay.""" + @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02) async def always_fail(): raise RuntimeError("fail") diff --git a/shared/tests/test_sa_models.py b/shared/tests/test_sa_models.py index de994c5..67c3c82 100644 --- a/shared/tests/test_sa_models.py +++ b/shared/tests/test_sa_models.py @@ -1,6 +1,5 @@ """Tests for SQLAlchemy ORM models.""" -import pytest from sqlalchemy import inspect @@ -117,9 +116,7 @@ class TestOrderRow: from shared.sa_models import OrderRow table = OrderRow.__table__ - fk_cols = { - fk.parent.name: fk.target_fullname for fk in table.foreign_keys - } + fk_cols = {fk.parent.name: fk.target_fullname for fk in table.foreign_keys} assert fk_cols == {"signal_id": "signals.id"} @@ -157,9 +154,7 @@ class TestTradeRow: from shared.sa_models import TradeRow table = TradeRow.__table__ - fk_cols = { - fk.parent.name: fk.target_fullname for fk in table.foreign_keys - } + fk_cols = {fk.parent.name: fk.target_fullname for fk in table.foreign_keys} assert fk_cols == {"order_id": "orders.id"} |
