summaryrefslogtreecommitdiff
path: root/shared
diff options
context:
space:
mode:
authorTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-02 15:36:45 +0900
committerTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-02 15:36:45 +0900
commite5fc21f3c9c890c254c5f74412aa0b68c3863042 (patch)
tree475f6ce445b9927c7c448ed3c3673c3d351e49ea /shared
parentbe7dc5311328d5d4bcb16cd613bcc88c26eaffa2 (diff)
feat: add config validation, SecretStr for secrets, API security fields
Diffstat (limited to 'shared')
-rw-r--r--shared/src/shared/config.py40
-rw-r--r--shared/tests/test_config_validation.py29
-rw-r--r--shared/tests/test_models.py7
3 files changed, 67 insertions, 9 deletions
diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py
index b6b9d69..0f1c66e 100644
--- a/shared/src/shared/config.py
+++ b/shared/src/shared/config.py
@@ -1,14 +1,15 @@
"""Shared configuration settings for the trading platform."""
+from pydantic import SecretStr, field_validator
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
- alpaca_api_key: str = ""
- alpaca_api_secret: str = ""
+ alpaca_api_key: SecretStr = SecretStr("")
+ alpaca_api_secret: SecretStr = SecretStr("")
alpaca_paper: bool = True # Use paper trading by default
- redis_url: str = "redis://localhost:6379"
- database_url: str = "postgresql://trading:trading@localhost:5432/trading"
+ redis_url: SecretStr = SecretStr("redis://localhost:6379")
+ database_url: SecretStr = SecretStr("postgresql://trading:trading@localhost:5432/trading")
db_pool_size: int = 20
db_max_overflow: int = 10
db_pool_recycle: int = 3600
@@ -30,20 +31,45 @@ class Settings(BaseSettings):
risk_max_consecutive_losses: int = 5
risk_loss_pause_minutes: int = 60
dry_run: bool = True
- telegram_bot_token: str = ""
+ telegram_bot_token: SecretStr = SecretStr("")
telegram_chat_id: str = ""
telegram_enabled: bool = False
log_format: str = "json"
health_port: int = 8080
metrics_auth_token: str = "" # If set, /health and /metrics require Bearer token
+ # API security
+ api_auth_token: SecretStr = SecretStr("")
+ cors_origins: str = "http://localhost:3000"
# News collector
- finnhub_api_key: str = ""
+ finnhub_api_key: SecretStr = SecretStr("")
news_poll_interval: int = 300
sentiment_aggregate_interval: int = 900
# Stock selector
selector_final_time: str = "15:30"
selector_max_picks: int = 3
# LLM
- anthropic_api_key: str = ""
+ anthropic_api_key: SecretStr = SecretStr("")
anthropic_model: str = "claude-sonnet-4-20250514"
model_config = {"env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"}
+
+ @field_validator("risk_max_position_size")
+ @classmethod
+ def validate_position_size(cls, v: float) -> float:
+ if v <= 0 or v > 1:
+ raise ValueError("risk_max_position_size must be between 0 and 1 (exclusive)")
+ return v
+
+ @field_validator("health_port")
+ @classmethod
+ def validate_health_port(cls, v: int) -> int:
+ if v < 1024 or v > 65535:
+ raise ValueError("health_port must be between 1024 and 65535")
+ return v
+
+ @field_validator("log_level")
+ @classmethod
+ def validate_log_level(cls, v: str) -> str:
+ valid = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
+ if v.upper() not in valid:
+ raise ValueError(f"log_level must be one of {valid}")
+ return v.upper()
diff --git a/shared/tests/test_config_validation.py b/shared/tests/test_config_validation.py
new file mode 100644
index 0000000..9376dc6
--- /dev/null
+++ b/shared/tests/test_config_validation.py
@@ -0,0 +1,29 @@
+"""Tests for config validation."""
+
+import pytest
+from pydantic import ValidationError
+
+from shared.config import Settings
+
+
+class TestConfigValidation:
+ def test_valid_defaults(self):
+ settings = Settings()
+ assert settings.risk_max_position_size == 0.1
+
+ def test_invalid_position_size(self):
+ with pytest.raises(ValidationError, match="risk_max_position_size"):
+ Settings(risk_max_position_size=-0.1)
+
+ def test_invalid_health_port(self):
+ with pytest.raises(ValidationError, match="health_port"):
+ Settings(health_port=80)
+
+ def test_invalid_log_level(self):
+ with pytest.raises(ValidationError, match="log_level"):
+ Settings(log_level="INVALID")
+
+ def test_secret_fields_masked(self):
+ settings = Settings(alpaca_api_key="my-secret-key")
+ assert "my-secret-key" not in repr(settings)
+ assert settings.alpaca_api_key.get_secret_value() == "my-secret-key"
diff --git a/shared/tests/test_models.py b/shared/tests/test_models.py
index 04098ce..21f9831 100644
--- a/shared/tests/test_models.py
+++ b/shared/tests/test_models.py
@@ -12,8 +12,11 @@ def test_settings_defaults():
with patch.dict(os.environ, {}, clear=False):
settings = Settings()
- assert settings.redis_url == "redis://localhost:6379"
- assert settings.database_url == "postgresql://trading:trading@localhost:5432/trading"
+ assert settings.redis_url.get_secret_value() == "redis://localhost:6379"
+ assert (
+ settings.database_url.get_secret_value()
+ == "postgresql://trading:trading@localhost:5432/trading"
+ )
assert settings.log_level == "INFO"
assert settings.risk_max_position_size == 0.1
assert settings.risk_stop_loss_pct == 5.0