diff options
Diffstat (limited to 'shared')
| -rw-r--r-- | shared/pyproject.toml | 3 | ||||
| -rw-r--r-- | shared/src/shared/alpaca.py | 190 | ||||
| -rw-r--r-- | shared/src/shared/config.py | 19 | ||||
| -rw-r--r-- | shared/src/shared/exchange.py | 50 | ||||
| -rw-r--r-- | shared/src/shared/sentiment.py | 35 | ||||
| -rw-r--r-- | shared/tests/test_alpaca.py | 69 | ||||
| -rw-r--r-- | shared/tests/test_broker.py | 6 | ||||
| -rw-r--r-- | shared/tests/test_db.py | 12 | ||||
| -rw-r--r-- | shared/tests/test_events.py | 10 | ||||
| -rw-r--r-- | shared/tests/test_exchange.py | 55 | ||||
| -rw-r--r-- | shared/tests/test_models.py | 38 | ||||
| -rw-r--r-- | shared/tests/test_notifier.py | 12 | ||||
| -rw-r--r-- | shared/tests/test_sentiment.py | 44 |
13 files changed, 394 insertions, 149 deletions
diff --git a/shared/pyproject.toml b/shared/pyproject.toml index c36f00b..830088d 100644 --- a/shared/pyproject.toml +++ b/shared/pyproject.toml @@ -23,6 +23,9 @@ dev = [ "pytest-asyncio>=0.23", "ruff>=0.4", ] +claude = [ + "anthropic>=0.40", +] [build-system] requires = ["hatchling"] diff --git a/shared/src/shared/alpaca.py b/shared/src/shared/alpaca.py new file mode 100644 index 0000000..7821592 --- /dev/null +++ b/shared/src/shared/alpaca.py @@ -0,0 +1,190 @@ +"""Alpaca Markets API client for US stock trading.""" + +import logging +from decimal import Decimal +from typing import Any + +import aiohttp + +logger = logging.getLogger(__name__) + +ALPACA_PAPER_URL = "https://paper-api.alpaca.markets" +ALPACA_LIVE_URL = "https://api.alpaca.markets" +ALPACA_DATA_URL = "https://data.alpaca.markets" + + +class AlpacaClient: + """Async client for Alpaca Trading and Market Data APIs.""" + + def __init__( + self, + api_key: str, + api_secret: str, + paper: bool = True, + ) -> None: + self._api_key = api_key + self._api_secret = api_secret + self._base_url = ALPACA_PAPER_URL if paper else ALPACA_LIVE_URL + self._data_url = ALPACA_DATA_URL + self._session: aiohttp.ClientSession | None = None + + @property + def headers(self) -> dict[str, str]: + return { + "APCA-API-KEY-ID": self._api_key, + "APCA-API-SECRET-KEY": self._api_secret, + } + + async def _ensure_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession(headers=self.headers) + return self._session + + async def _request(self, method: str, url: str, **kwargs) -> dict | list: + session = await self._ensure_session() + async with session.request(method, url, **kwargs) as resp: + if resp.status >= 400: + body = await resp.text() + logger.error("Alpaca API error %d: %s", resp.status, body) + raise RuntimeError(f"Alpaca API error {resp.status}: {body}") + return await resp.json() + + # --- Account --- + async def get_account(self) -> dict: + return await self._request("GET", f"{self._base_url}/v2/account") + + async def get_buying_power(self) -> Decimal: + account = await self.get_account() + return Decimal(str(account.get("buying_power", "0"))) + + # --- Orders --- + async def submit_order( + self, + symbol: str, + qty: float | None = None, + notional: float | None = None, + side: str = "buy", + type: str = "market", + time_in_force: str = "day", + ) -> dict: + """Submit an order. + + Args: + symbol: Stock ticker (e.g., "AAPL") + qty: Number of shares (or use notional for dollar amount) + notional: Dollar amount to buy (fractional shares) + side: "buy" or "sell" + type: "market", "limit", "stop", "stop_limit", "market_on_close" + time_in_force: "day", "gtc", "opg", "cls" + """ + data: dict[str, Any] = { + "symbol": symbol, + "side": side, + "type": type, + "time_in_force": time_in_force, + } + if qty is not None: + data["qty"] = str(qty) + elif notional is not None: + data["notional"] = str(notional) + + return await self._request("POST", f"{self._base_url}/v2/orders", json=data) + + async def submit_moc_order( + self, + symbol: str, + qty: float, + side: str = "buy", + ) -> dict: + """Submit a Market on Close order.""" + return await self.submit_order( + symbol=symbol, + qty=qty, + side=side, + type="market", + time_in_force="cls", + ) + + async def get_orders(self, status: str = "open", limit: int = 50) -> list: + return await self._request( + "GET", + f"{self._base_url}/v2/orders", + params={"status": status, "limit": limit}, + ) + + async def cancel_all_orders(self) -> list: + return await self._request("DELETE", f"{self._base_url}/v2/orders") + + # --- Positions --- + async def get_positions(self) -> list: + return await self._request("GET", f"{self._base_url}/v2/positions") + + async def get_position(self, symbol: str) -> dict | None: + try: + return await self._request("GET", f"{self._base_url}/v2/positions/{symbol}") + except RuntimeError: + return None + + async def close_position(self, symbol: str) -> dict: + return await self._request("DELETE", f"{self._base_url}/v2/positions/{symbol}") + + async def close_all_positions(self) -> list: + return await self._request("DELETE", f"{self._base_url}/v2/positions") + + # --- Market Data --- + async def get_bars( + self, + symbol: str, + timeframe: str = "1Day", + start: str | None = None, + end: str | None = None, + limit: int = 100, + ) -> list[dict]: + """Get historical bars. + + Args: + symbol: Stock ticker + timeframe: "1Min", "5Min", "15Min", "1Hour", "1Day" + start: RFC3339 date string + end: RFC3339 date string + limit: Max bars to return + """ + params: dict[str, Any] = {"timeframe": timeframe, "limit": limit} + if start: + params["start"] = start + if end: + params["end"] = end + + data = await self._request( + "GET", + f"{self._data_url}/v2/stocks/{symbol}/bars", + params=params, + ) + return data.get("bars", []) + + async def get_latest_quote(self, symbol: str) -> dict: + data = await self._request( + "GET", + f"{self._data_url}/v2/stocks/{symbol}/quotes/latest", + ) + return data.get("quote", {}) + + async def get_snapshot(self, symbol: str) -> dict: + return await self._request( + "GET", + f"{self._data_url}/v2/stocks/{symbol}/snapshot", + ) + + # --- Market Status --- + async def get_clock(self) -> dict: + """Get market clock (is_open, next_open, next_close).""" + return await self._request("GET", f"{self._base_url}/v2/clock") + + async def is_market_open(self) -> bool: + clock = await self.get_clock() + return clock.get("is_open", False) + + # --- Cleanup --- + async def close(self) -> None: + if self._session and not self._session.closed: + await self._session.close() diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py index ab0331c..4e8e7f1 100644 --- a/shared/src/shared/config.py +++ b/shared/src/shared/config.py @@ -4,12 +4,11 @@ from pydantic_settings import BaseSettings class Settings(BaseSettings): - binance_api_key: str - binance_api_secret: str + alpaca_api_key: str = "" + alpaca_api_secret: str = "" + alpaca_paper: bool = True # Use paper trading by default redis_url: str = "redis://localhost:6379" database_url: str = "postgresql://trading:trading@localhost:5432/trading" - exchange_id: str = "binance" # Any ccxt exchange ID - exchange_sandbox: bool = False # Use sandbox/testnet mode log_level: str = "INFO" risk_max_position_size: float = 0.1 risk_stop_loss_pct: float = 5.0 @@ -18,6 +17,15 @@ class Settings(BaseSettings): risk_max_open_positions: int = 10 risk_volatility_lookback: int = 20 risk_volatility_scale: bool = False + risk_max_portfolio_exposure: float = 0.8 + risk_max_correlated_exposure: float = 0.5 + risk_correlation_threshold: float = 0.7 + risk_var_confidence: float = 0.95 + risk_var_limit_pct: float = 5.0 + risk_drawdown_reduction_threshold: float = 0.1 + risk_drawdown_halt_threshold: float = 0.2 + risk_max_consecutive_losses: int = 5 + risk_loss_pause_minutes: int = 60 dry_run: bool = True telegram_bot_token: str = "" telegram_chat_id: str = "" @@ -27,5 +35,4 @@ class Settings(BaseSettings): circuit_breaker_threshold: int = 5 circuit_breaker_timeout: int = 60 metrics_auth_token: str = "" # If set, /health and /metrics require Bearer token - - model_config = {"env_file": ".env", "env_file_encoding": "utf-8"} + model_config = {"env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"} diff --git a/shared/src/shared/exchange.py b/shared/src/shared/exchange.py deleted file mode 100644 index 7afcd92..0000000 --- a/shared/src/shared/exchange.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Exchange factory using ccxt.""" - -import ccxt.async_support as ccxt - - -def create_exchange( - exchange_id: str, - api_key: str, - api_secret: str, - sandbox: bool = False, -) -> ccxt.Exchange: - """Create a ccxt async exchange instance by ID. - - Args: - exchange_id: ccxt exchange ID (e.g. 'binance', 'bybit', 'okx', 'kraken') - api_key: API key - api_secret: API secret - sandbox: Use sandbox/testnet mode - - Returns: - Configured ccxt async exchange instance - - Raises: - ValueError: If exchange_id is not supported by ccxt - """ - if not hasattr(ccxt, exchange_id): - available = [ - x - for x in dir(ccxt) - if not x.startswith("_") - and isinstance(getattr(ccxt, x, None), type) - and issubclass(getattr(ccxt, x), ccxt.Exchange) - ] - raise ValueError( - f"Unknown exchange '{exchange_id}'. Available: {', '.join(sorted(available)[:20])}..." - ) - - exchange_cls = getattr(ccxt, exchange_id) - exchange = exchange_cls( - { - "apiKey": api_key, - "secret": api_secret, - "enableRateLimit": True, - } - ) - - if sandbox: - exchange.set_sandbox_mode(True) - - return exchange diff --git a/shared/src/shared/sentiment.py b/shared/src/shared/sentiment.py new file mode 100644 index 0000000..8213b47 --- /dev/null +++ b/shared/src/shared/sentiment.py @@ -0,0 +1,35 @@ +"""Market sentiment data.""" + +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone + +logger = logging.getLogger(__name__) + + +@dataclass +class SentimentData: + """Aggregated sentiment snapshot.""" + + fear_greed_value: int | None = None + fear_greed_label: str | None = None + news_sentiment: float | None = None + news_count: int = 0 + exchange_netflow: float | None = None + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + @property + def should_buy(self) -> bool: + if self.fear_greed_value is not None and self.fear_greed_value > 70: + return False + if self.news_sentiment is not None and self.news_sentiment < -0.3: + return False + return True + + @property + def should_block(self) -> bool: + if self.fear_greed_value is not None and self.fear_greed_value > 80: + return True + if self.news_sentiment is not None and self.news_sentiment < -0.5: + return True + return False diff --git a/shared/tests/test_alpaca.py b/shared/tests/test_alpaca.py new file mode 100644 index 0000000..080b7c4 --- /dev/null +++ b/shared/tests/test_alpaca.py @@ -0,0 +1,69 @@ +"""Tests for Alpaca API client.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock +from shared.alpaca import AlpacaClient + + +@pytest.fixture +def client(): + return AlpacaClient(api_key="test-key", api_secret="test-secret", paper=True) + + +def test_client_uses_paper_url(client): + assert "paper" in client._base_url + + +def test_client_uses_live_url(): + c = AlpacaClient(api_key="k", api_secret="s", paper=False) + assert "paper" not in c._base_url + + +def test_client_headers(client): + h = client.headers + assert h["APCA-API-KEY-ID"] == "test-key" + assert h["APCA-API-SECRET-KEY"] == "test-secret" + + +@pytest.mark.asyncio +async def test_get_buying_power(client): + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"buying_power": "10000.00"}) + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=False) + + mock_session = MagicMock() + mock_session.closed = False + mock_session.request = MagicMock(return_value=mock_response) + mock_session.close = AsyncMock() + client._session = mock_session + + result = await client.get_buying_power() + from decimal import Decimal + + assert result == Decimal("10000.00") + await client.close() + + +@pytest.mark.asyncio +async def test_submit_moc_order(client): + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"id": "order-1", "status": "accepted"}) + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=False) + + mock_session = MagicMock() + mock_session.closed = False + mock_session.request = MagicMock(return_value=mock_response) + mock_session.close = AsyncMock() + client._session = mock_session + + result = await client.submit_moc_order("AAPL", qty=10, side="buy") + assert result["id"] == "order-1" + + # Verify the request was made with correct params + call_args = mock_session.request.call_args + assert call_args[0][0] == "POST" + await client.close() diff --git a/shared/tests/test_broker.py b/shared/tests/test_broker.py index 9be84b0..eb1582d 100644 --- a/shared/tests/test_broker.py +++ b/shared/tests/test_broker.py @@ -16,7 +16,7 @@ async def test_broker_publish(): from shared.broker import RedisBroker broker = RedisBroker("redis://localhost:6379") - data = {"type": "CANDLE", "symbol": "BTCUSDT"} + data = {"type": "CANDLE", "symbol": "AAPL"} await broker.publish("candles", data) mock_redis.xadd.assert_called_once() @@ -35,7 +35,7 @@ async def test_broker_subscribe_returns_messages(): mock_redis = AsyncMock() mock_from_url.return_value = mock_redis - payload_data = {"type": "CANDLE", "symbol": "ETHUSDT"} + payload_data = {"type": "CANDLE", "symbol": "MSFT"} mock_redis.xread.return_value = [ [ b"candles", @@ -53,7 +53,7 @@ async def test_broker_subscribe_returns_messages(): mock_redis.xread.assert_called_once() assert len(messages) == 1 assert messages[0]["type"] == "CANDLE" - assert messages[0]["symbol"] == "ETHUSDT" + assert messages[0]["symbol"] == "MSFT" @pytest.mark.asyncio diff --git a/shared/tests/test_db.py b/shared/tests/test_db.py index d33dfe1..239ee64 100644 --- a/shared/tests/test_db.py +++ b/shared/tests/test_db.py @@ -10,7 +10,7 @@ def make_candle(): from shared.models import Candle return Candle( - symbol="BTCUSDT", + symbol="AAPL", timeframe="1m", open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), open=Decimal("50000"), @@ -27,7 +27,7 @@ def make_signal(): return Signal( id="sig-1", strategy="ma_cross", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, price=Decimal("50000"), quantity=Decimal("0.1"), @@ -42,7 +42,7 @@ def make_order(): return Order( id="ord-1", signal_id="sig-1", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, type=OrderType.LIMIT, price=Decimal("50000"), @@ -228,7 +228,7 @@ class TestGetCandles: # Create a mock row that behaves like a SA result row mock_row = MagicMock() mock_row._mapping = { - "symbol": "BTCUSDT", + "symbol": "AAPL", "timeframe": "1m", "open_time": datetime(2024, 1, 1, tzinfo=timezone.utc), "open": Decimal("50000"), @@ -248,11 +248,11 @@ class TestGetCandles: db._session_factory = MagicMock(return_value=mock_session) - result = await db.get_candles("BTCUSDT", "1m", 500) + result = await db.get_candles("AAPL", "1m", 500) assert isinstance(result, list) assert len(result) == 1 - assert result[0]["symbol"] == "BTCUSDT" + assert result[0]["symbol"] == "AAPL" mock_session.execute.assert_awaited_once() diff --git a/shared/tests/test_events.py b/shared/tests/test_events.py index ab7792b..6077d93 100644 --- a/shared/tests/test_events.py +++ b/shared/tests/test_events.py @@ -8,7 +8,7 @@ def make_candle(): from shared.models import Candle return Candle( - symbol="BTCUSDT", + symbol="AAPL", timeframe="1m", open_time=datetime(2024, 1, 1, tzinfo=timezone.utc), open=Decimal("50000"), @@ -24,7 +24,7 @@ def make_signal(): return Signal( strategy="test", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, price=Decimal("50000"), quantity=Decimal("0.01"), @@ -40,7 +40,7 @@ def test_candle_event_serialize(): event = CandleEvent(data=candle) d = event.to_dict() assert d["type"] == EventType.CANDLE - assert d["data"]["symbol"] == "BTCUSDT" + assert d["data"]["symbol"] == "AAPL" assert d["data"]["timeframe"] == "1m" @@ -53,7 +53,7 @@ def test_candle_event_deserialize(): d = event.to_dict() restored = CandleEvent.from_raw(d) assert restored.type == EventType.CANDLE - assert restored.data.symbol == "BTCUSDT" + assert restored.data.symbol == "AAPL" assert restored.data.close == Decimal("50500") @@ -65,7 +65,7 @@ def test_signal_event_serialize(): event = SignalEvent(data=signal) d = event.to_dict() assert d["type"] == EventType.SIGNAL - assert d["data"]["symbol"] == "BTCUSDT" + assert d["data"]["symbol"] == "AAPL" assert d["data"]["strategy"] == "test" diff --git a/shared/tests/test_exchange.py b/shared/tests/test_exchange.py deleted file mode 100644 index 95dc7d7..0000000 --- a/shared/tests/test_exchange.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Tests for the exchange factory.""" - -from unittest.mock import patch - -import ccxt.async_support as ccxt -import pytest - -from shared.exchange import create_exchange - - -def test_create_exchange_binance(): - """Verify create_exchange returns a ccxt.binance instance.""" - exchange = create_exchange( - exchange_id="binance", - api_key="test-key", - api_secret="test-secret", - ) - assert isinstance(exchange, ccxt.binance) - assert exchange.apiKey == "test-key" - assert exchange.secret == "test-secret" - assert exchange.enableRateLimit is True - - -def test_create_exchange_unknown(): - """Verify create_exchange raises ValueError for unknown exchange.""" - with pytest.raises(ValueError, match="Unknown exchange 'not_a_real_exchange'"): - create_exchange( - exchange_id="not_a_real_exchange", - api_key="key", - api_secret="secret", - ) - - -def test_create_exchange_with_sandbox(): - """Verify sandbox mode is activated when sandbox=True.""" - with patch.object(ccxt.binance, "set_sandbox_mode") as mock_sandbox: - exchange = create_exchange( - exchange_id="binance", - api_key="key", - api_secret="secret", - sandbox=True, - ) - mock_sandbox.assert_called_once_with(True) - assert isinstance(exchange, ccxt.binance) - - -def test_create_exchange_no_sandbox_by_default(): - """Verify sandbox mode is not set when sandbox=False (default).""" - with patch.object(ccxt.binance, "set_sandbox_mode") as mock_sandbox: - create_exchange( - exchange_id="binance", - api_key="key", - api_secret="secret", - ) - mock_sandbox.assert_not_called() diff --git a/shared/tests/test_models.py b/shared/tests/test_models.py index b23d71d..04098ce 100644 --- a/shared/tests/test_models.py +++ b/shared/tests/test_models.py @@ -8,15 +8,9 @@ from unittest.mock import patch def test_settings_defaults(): """Test that Settings has correct defaults.""" - with patch.dict( - os.environ, - { - "BINANCE_API_KEY": "test_key", - "BINANCE_API_SECRET": "test_secret", - }, - ): - from shared.config import Settings + from shared.config import Settings + with patch.dict(os.environ, {}, clear=False): settings = Settings() assert settings.redis_url == "redis://localhost:6379" assert settings.database_url == "postgresql://trading:trading@localhost:5432/trading" @@ -33,7 +27,7 @@ def test_candle_creation(): now = datetime.now(timezone.utc) candle = Candle( - symbol="BTCUSDT", + symbol="AAPL", timeframe="1m", open_time=now, open=Decimal("50000.00"), @@ -42,7 +36,7 @@ def test_candle_creation(): close=Decimal("50500.00"), volume=Decimal("100.5"), ) - assert candle.symbol == "BTCUSDT" + assert candle.symbol == "AAPL" assert candle.timeframe == "1m" assert candle.open == Decimal("50000.00") assert candle.high == Decimal("51000.00") @@ -57,14 +51,14 @@ def test_signal_creation(): signal = Signal( strategy="rsi_strategy", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, price=Decimal("50000.00"), quantity=Decimal("0.01"), reason="RSI oversold", ) assert signal.strategy == "rsi_strategy" - assert signal.symbol == "BTCUSDT" + assert signal.symbol == "AAPL" assert signal.side == OrderSide.BUY assert signal.price == Decimal("50000.00") assert signal.quantity == Decimal("0.01") @@ -81,7 +75,7 @@ def test_order_creation(): signal_id = str(uuid.uuid4()) order = Order( signal_id=signal_id, - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, type=OrderType.MARKET, price=Decimal("50000.00"), @@ -99,8 +93,12 @@ def test_signal_conviction_default(): from shared.models import Signal, OrderSide signal = Signal( - strategy="rsi", symbol="BTCUSDT", side=OrderSide.BUY, - price=Decimal("50000"), quantity=Decimal("0.01"), reason="test", + strategy="rsi", + symbol="AAPL", + side=OrderSide.BUY, + price=Decimal("50000"), + quantity=Decimal("0.01"), + reason="test", ) assert signal.conviction == 1.0 assert signal.stop_loss is None @@ -112,8 +110,12 @@ def test_signal_with_stops(): from shared.models import Signal, OrderSide signal = Signal( - strategy="rsi", symbol="BTCUSDT", side=OrderSide.BUY, - price=Decimal("50000"), quantity=Decimal("0.01"), reason="test", + strategy="rsi", + symbol="AAPL", + side=OrderSide.BUY, + price=Decimal("50000"), + quantity=Decimal("0.01"), + reason="test", conviction=0.8, stop_loss=Decimal("48000"), take_profit=Decimal("55000"), @@ -128,7 +130,7 @@ def test_position_unrealized_pnl(): from shared.models import Position position = Position( - symbol="BTCUSDT", + symbol="AAPL", quantity=Decimal("0.1"), avg_entry_price=Decimal("50000"), current_price=Decimal("51000"), diff --git a/shared/tests/test_notifier.py b/shared/tests/test_notifier.py index 3d29830..6c81369 100644 --- a/shared/tests/test_notifier.py +++ b/shared/tests/test_notifier.py @@ -86,7 +86,7 @@ class TestTelegramNotifierFormatters: notifier = TelegramNotifier(bot_token="fake-token", chat_id="123") signal = Signal( strategy="rsi_strategy", - symbol="BTCUSDT", + symbol="AAPL", side=OrderSide.BUY, price=Decimal("50000.00"), quantity=Decimal("0.01"), @@ -99,7 +99,7 @@ class TestTelegramNotifierFormatters: msg = mock_send.call_args[0][0] assert "BUY" in msg assert "rsi_strategy" in msg - assert "BTCUSDT" in msg + assert "AAPL" in msg assert "50000.00" in msg assert "0.01" in msg assert "RSI oversold" in msg @@ -109,7 +109,7 @@ class TestTelegramNotifierFormatters: notifier = TelegramNotifier(bot_token="fake-token", chat_id="123") order = Order( signal_id=str(uuid.uuid4()), - symbol="ETHUSDT", + symbol="MSFT", side=OrderSide.SELL, type=OrderType.LIMIT, price=Decimal("3000.50"), @@ -122,7 +122,7 @@ class TestTelegramNotifierFormatters: mock_send.assert_called_once() msg = mock_send.call_args[0][0] assert "FILLED" in msg - assert "ETHUSDT" in msg + assert "MSFT" in msg assert "SELL" in msg assert "3000.50" in msg assert "1.5" in msg @@ -143,7 +143,7 @@ class TestTelegramNotifierFormatters: notifier = TelegramNotifier(bot_token="fake-token", chat_id="123") positions = [ Position( - symbol="BTCUSDT", + symbol="AAPL", quantity=Decimal("0.1"), avg_entry_price=Decimal("50000"), current_price=Decimal("51000"), @@ -158,7 +158,7 @@ class TestTelegramNotifierFormatters: ) mock_send.assert_called_once() msg = mock_send.call_args[0][0] - assert "BTCUSDT" in msg + assert "AAPL" in msg assert "5100.00" in msg assert "100.00" in msg diff --git a/shared/tests/test_sentiment.py b/shared/tests/test_sentiment.py new file mode 100644 index 0000000..9bd8ea3 --- /dev/null +++ b/shared/tests/test_sentiment.py @@ -0,0 +1,44 @@ +"""Tests for market sentiment module.""" + +from shared.sentiment import SentimentData + + +def test_sentiment_should_buy_default_no_data(): + s = SentimentData() + assert s.should_buy is True + assert s.should_block is False + + +def test_sentiment_should_buy_low_fear_greed(): + s = SentimentData(fear_greed_value=15) + assert s.should_buy is True + + +def test_sentiment_should_not_buy_on_greed(): + s = SentimentData(fear_greed_value=75) + assert s.should_buy is False + + +def test_sentiment_should_not_buy_negative_news(): + s = SentimentData(news_sentiment=-0.4) + assert s.should_buy is False + + +def test_sentiment_should_buy_positive_news(): + s = SentimentData(fear_greed_value=50, news_sentiment=0.3) + assert s.should_buy is True + + +def test_sentiment_should_block_extreme_greed(): + s = SentimentData(fear_greed_value=85) + assert s.should_block is True + + +def test_sentiment_should_block_very_negative_news(): + s = SentimentData(news_sentiment=-0.6) + assert s.should_block is True + + +def test_sentiment_no_block_on_neutral(): + s = SentimentData(fear_greed_value=50, news_sentiment=0.0) + assert s.should_block is False |
