diff options
| author | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-02 15:36:45 +0900 |
|---|---|---|
| committer | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-02 15:36:45 +0900 |
| commit | e5fc21f3c9c890c254c5f74412aa0b68c3863042 (patch) | |
| tree | 475f6ce445b9927c7c448ed3c3673c3d351e49ea | |
| parent | be7dc5311328d5d4bcb16cd613bcc88c26eaffa2 (diff) | |
feat: add config validation, SecretStr for secrets, API security fields
| -rw-r--r-- | cli/src/trading_cli/commands/backtest.py | 4 | ||||
| -rw-r--r-- | cli/src/trading_cli/commands/data.py | 8 | ||||
| -rw-r--r-- | cli/src/trading_cli/commands/portfolio.py | 4 | ||||
| -rw-r--r-- | services/api/src/trading_api/main.py | 2 | ||||
| -rw-r--r-- | services/backtester/src/backtester/main.py | 2 | ||||
| -rw-r--r-- | services/data-collector/src/data_collector/main.py | 10 | ||||
| -rw-r--r-- | services/news-collector/src/news_collector/main.py | 8 | ||||
| -rw-r--r-- | services/order-executor/src/order_executor/main.py | 10 | ||||
| -rw-r--r-- | services/portfolio-manager/src/portfolio_manager/main.py | 6 | ||||
| -rw-r--r-- | services/strategy-engine/src/strategy_engine/main.py | 14 | ||||
| -rw-r--r-- | shared/src/shared/config.py | 40 | ||||
| -rw-r--r-- | shared/tests/test_config_validation.py | 29 | ||||
| -rw-r--r-- | shared/tests/test_models.py | 7 |
13 files changed, 101 insertions, 43 deletions
diff --git a/cli/src/trading_cli/commands/backtest.py b/cli/src/trading_cli/commands/backtest.py index 3876f1b..ad21f8f 100644 --- a/cli/src/trading_cli/commands/backtest.py +++ b/cli/src/trading_cli/commands/backtest.py @@ -58,7 +58,7 @@ def run(strategy, symbol, timeframe, balance, output_format, file_path): async def _run(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: candle_rows = await db.get_candles(symbol, timeframe, limit=500) @@ -131,7 +131,7 @@ def walk_forward(strategy, symbol, timeframe, balance, windows): async def _run(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: rows = await db.get_candles(symbol, timeframe, limit=2000) diff --git a/cli/src/trading_cli/commands/data.py b/cli/src/trading_cli/commands/data.py index 1ecc15f..8797564 100644 --- a/cli/src/trading_cli/commands/data.py +++ b/cli/src/trading_cli/commands/data.py @@ -49,7 +49,7 @@ def history(symbol, timeframe, since, limit): from datetime import datetime, timezone settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() start = None @@ -64,8 +64,8 @@ def history(symbol, timeframe, since, limit): sys.exit(1) client = AlpacaClient( - api_key=settings.alpaca_api_key, - api_secret=settings.alpaca_api_secret, + api_key=settings.alpaca_api_key.get_secret_value(), + api_secret=settings.alpaca_api_secret.get_secret_value(), base_url=getattr(settings, "alpaca_base_url", "https://paper-api.alpaca.markets"), ) @@ -107,7 +107,7 @@ def list_(): async def _list(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: stmt = ( diff --git a/cli/src/trading_cli/commands/portfolio.py b/cli/src/trading_cli/commands/portfolio.py index ad9a6b4..4f49894 100644 --- a/cli/src/trading_cli/commands/portfolio.py +++ b/cli/src/trading_cli/commands/portfolio.py @@ -27,7 +27,7 @@ def show(): async def _show(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: async with db.get_session() as session: @@ -81,7 +81,7 @@ def history(days): async def _history(): settings = Settings() - db = Database(settings.database_url) + db = Database(settings.database_url.get_secret_value()) await db.connect() try: since = datetime.now(timezone.utc) - timedelta(days=days) diff --git a/services/api/src/trading_api/main.py b/services/api/src/trading_api/main.py index 39f7b43..87306b2 100644 --- a/services/api/src/trading_api/main.py +++ b/services/api/src/trading_api/main.py @@ -13,7 +13,7 @@ from trading_api.routers import portfolio, orders, strategies @asynccontextmanager async def lifespan(app: FastAPI): settings = Settings() - app.state.db = Database(settings.database_url) + app.state.db = Database(settings.database_url.get_secret_value()) await app.state.db.connect() yield await app.state.db.close() diff --git a/services/backtester/src/backtester/main.py b/services/backtester/src/backtester/main.py index a4cea76..084ce02 100644 --- a/services/backtester/src/backtester/main.py +++ b/services/backtester/src/backtester/main.py @@ -45,7 +45,7 @@ async def run_backtest() -> str: except Exception as exc: raise RuntimeError(f"Failed to load strategy '{config.strategy_name}': {exc}") from exc - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() try: rows = await db.get_candles(config.symbol, config.timeframe, config.candle_limit) diff --git a/services/data-collector/src/data_collector/main.py b/services/data-collector/src/data_collector/main.py index b42b34c..608d6cd 100644 --- a/services/data-collector/src/data_collector/main.py +++ b/services/data-collector/src/data_collector/main.py @@ -56,18 +56,18 @@ async def run() -> None: metrics = ServiceMetrics("data_collector") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id, ) - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) alpaca = AlpacaClient( - api_key=config.alpaca_api_key, - api_secret=config.alpaca_api_secret, + api_key=config.alpaca_api_key.get_secret_value(), + api_secret=config.alpaca_api_secret.get_secret_value(), paper=config.alpaca_paper, ) diff --git a/services/news-collector/src/news_collector/main.py b/services/news-collector/src/news_collector/main.py index 3493f7c..f56914f 100644 --- a/services/news-collector/src/news_collector/main.py +++ b/services/news-collector/src/news_collector/main.py @@ -115,14 +115,14 @@ async def run() -> None: metrics = ServiceMetrics("news_collector") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id, ) - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) health = HealthCheckServer( "news-collector", @@ -133,7 +133,7 @@ async def run() -> None: metrics.service_up.labels(service="news-collector").set(1) # Build collectors - finnhub = FinnhubCollector(api_key=config.finnhub_api_key) + finnhub = FinnhubCollector(api_key=config.finnhub_api_key.get_secret_value()) rss = RSSCollector() sec = SecEdgarCollector() truth = TruthSocialCollector() diff --git a/services/order-executor/src/order_executor/main.py b/services/order-executor/src/order_executor/main.py index 51ab286..1d167ef 100644 --- a/services/order-executor/src/order_executor/main.py +++ b/services/order-executor/src/order_executor/main.py @@ -26,18 +26,18 @@ async def run() -> None: metrics = ServiceMetrics("order_executor") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id, ) - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) alpaca = AlpacaClient( - api_key=config.alpaca_api_key, - api_secret=config.alpaca_api_secret, + api_key=config.alpaca_api_key.get_secret_value(), + api_secret=config.alpaca_api_secret.get_secret_value(), paper=config.alpaca_paper, ) diff --git a/services/portfolio-manager/src/portfolio_manager/main.py b/services/portfolio-manager/src/portfolio_manager/main.py index a6823ae..0214099 100644 --- a/services/portfolio-manager/src/portfolio_manager/main.py +++ b/services/portfolio-manager/src/portfolio_manager/main.py @@ -61,10 +61,10 @@ async def run() -> None: log = setup_logging("portfolio-manager", config.log_level, config.log_format) metrics = ServiceMetrics("portfolio_manager") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, chat_id=config.telegram_chat_id + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id ) - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) tracker = PortfolioTracker() health = HealthCheckServer( @@ -76,7 +76,7 @@ async def run() -> None: await health.start() metrics.service_up.labels(service="portfolio-manager").set(1) - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() snapshot_task = asyncio.create_task( diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py index 5a30766..2635b7c 100644 --- a/services/strategy-engine/src/strategy_engine/main.py +++ b/services/strategy-engine/src/strategy_engine/main.py @@ -76,18 +76,18 @@ async def run() -> None: metrics = ServiceMetrics("strategy_engine") notifier = TelegramNotifier( - bot_token=config.telegram_bot_token, + bot_token=config.telegram_bot_token.get_secret_value(), chat_id=config.telegram_chat_id, ) - broker = RedisBroker(config.redis_url) + broker = RedisBroker(config.redis_url.get_secret_value()) - db = Database(config.database_url) + db = Database(config.database_url.get_secret_value()) await db.connect() alpaca = AlpacaClient( - api_key=config.alpaca_api_key, - api_secret=config.alpaca_api_secret, + api_key=config.alpaca_api_key.get_secret_value(), + api_secret=config.alpaca_api_secret.get_secret_value(), paper=config.alpaca_paper, ) @@ -117,12 +117,12 @@ async def run() -> None: task = asyncio.create_task(process_symbol(engine, stream, log)) tasks.append(task) - if config.anthropic_api_key: + if config.anthropic_api_key.get_secret_value(): selector = StockSelector( db=db, broker=broker, alpaca=alpaca, - anthropic_api_key=config.anthropic_api_key, + anthropic_api_key=config.anthropic_api_key.get_secret_value(), anthropic_model=config.anthropic_model, max_picks=config.selector_max_picks, ) diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py index b6b9d69..0f1c66e 100644 --- a/shared/src/shared/config.py +++ b/shared/src/shared/config.py @@ -1,14 +1,15 @@ """Shared configuration settings for the trading platform.""" +from pydantic import SecretStr, field_validator from pydantic_settings import BaseSettings class Settings(BaseSettings): - alpaca_api_key: str = "" - alpaca_api_secret: str = "" + alpaca_api_key: SecretStr = SecretStr("") + alpaca_api_secret: SecretStr = SecretStr("") alpaca_paper: bool = True # Use paper trading by default - redis_url: str = "redis://localhost:6379" - database_url: str = "postgresql://trading:trading@localhost:5432/trading" + redis_url: SecretStr = SecretStr("redis://localhost:6379") + database_url: SecretStr = SecretStr("postgresql://trading:trading@localhost:5432/trading") db_pool_size: int = 20 db_max_overflow: int = 10 db_pool_recycle: int = 3600 @@ -30,20 +31,45 @@ class Settings(BaseSettings): risk_max_consecutive_losses: int = 5 risk_loss_pause_minutes: int = 60 dry_run: bool = True - telegram_bot_token: str = "" + telegram_bot_token: SecretStr = SecretStr("") telegram_chat_id: str = "" telegram_enabled: bool = False log_format: str = "json" health_port: int = 8080 metrics_auth_token: str = "" # If set, /health and /metrics require Bearer token + # API security + api_auth_token: SecretStr = SecretStr("") + cors_origins: str = "http://localhost:3000" # News collector - finnhub_api_key: str = "" + finnhub_api_key: SecretStr = SecretStr("") news_poll_interval: int = 300 sentiment_aggregate_interval: int = 900 # Stock selector selector_final_time: str = "15:30" selector_max_picks: int = 3 # LLM - anthropic_api_key: str = "" + anthropic_api_key: SecretStr = SecretStr("") anthropic_model: str = "claude-sonnet-4-20250514" model_config = {"env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"} + + @field_validator("risk_max_position_size") + @classmethod + def validate_position_size(cls, v: float) -> float: + if v <= 0 or v > 1: + raise ValueError("risk_max_position_size must be between 0 and 1 (exclusive)") + return v + + @field_validator("health_port") + @classmethod + def validate_health_port(cls, v: int) -> int: + if v < 1024 or v > 65535: + raise ValueError("health_port must be between 1024 and 65535") + return v + + @field_validator("log_level") + @classmethod + def validate_log_level(cls, v: str) -> str: + valid = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} + if v.upper() not in valid: + raise ValueError(f"log_level must be one of {valid}") + return v.upper() diff --git a/shared/tests/test_config_validation.py b/shared/tests/test_config_validation.py new file mode 100644 index 0000000..9376dc6 --- /dev/null +++ b/shared/tests/test_config_validation.py @@ -0,0 +1,29 @@ +"""Tests for config validation.""" + +import pytest +from pydantic import ValidationError + +from shared.config import Settings + + +class TestConfigValidation: + def test_valid_defaults(self): + settings = Settings() + assert settings.risk_max_position_size == 0.1 + + def test_invalid_position_size(self): + with pytest.raises(ValidationError, match="risk_max_position_size"): + Settings(risk_max_position_size=-0.1) + + def test_invalid_health_port(self): + with pytest.raises(ValidationError, match="health_port"): + Settings(health_port=80) + + def test_invalid_log_level(self): + with pytest.raises(ValidationError, match="log_level"): + Settings(log_level="INVALID") + + def test_secret_fields_masked(self): + settings = Settings(alpaca_api_key="my-secret-key") + assert "my-secret-key" not in repr(settings) + assert settings.alpaca_api_key.get_secret_value() == "my-secret-key" diff --git a/shared/tests/test_models.py b/shared/tests/test_models.py index 04098ce..21f9831 100644 --- a/shared/tests/test_models.py +++ b/shared/tests/test_models.py @@ -12,8 +12,11 @@ def test_settings_defaults(): with patch.dict(os.environ, {}, clear=False): settings = Settings() - assert settings.redis_url == "redis://localhost:6379" - assert settings.database_url == "postgresql://trading:trading@localhost:5432/trading" + assert settings.redis_url.get_secret_value() == "redis://localhost:6379" + assert ( + settings.database_url.get_secret_value() + == "postgresql://trading:trading@localhost:5432/trading" + ) assert settings.log_level == "INFO" assert settings.risk_max_position_size == 0.1 assert settings.risk_stop_loss_pct == 5.0 |
