summaryrefslogtreecommitdiff
path: root/services/api/src
diff options
context:
space:
mode:
Diffstat (limited to 'services/api/src')
-rw-r--r--services/api/src/trading_api/dependencies/__init__.py0
-rw-r--r--services/api/src/trading_api/dependencies/auth.py29
-rw-r--r--services/api/src/trading_api/main.py49
-rw-r--r--services/api/src/trading_api/routers/orders.py11
-rw-r--r--services/api/src/trading_api/routers/portfolio.py4
5 files changed, 83 insertions, 10 deletions
diff --git a/services/api/src/trading_api/dependencies/__init__.py b/services/api/src/trading_api/dependencies/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/services/api/src/trading_api/dependencies/__init__.py
diff --git a/services/api/src/trading_api/dependencies/auth.py b/services/api/src/trading_api/dependencies/auth.py
new file mode 100644
index 0000000..a5e76c1
--- /dev/null
+++ b/services/api/src/trading_api/dependencies/auth.py
@@ -0,0 +1,29 @@
+"""Bearer token authentication dependency."""
+
+import logging
+
+from fastapi import Depends, HTTPException, status
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+
+from shared.config import Settings
+
+logger = logging.getLogger(__name__)
+
+_security = HTTPBearer(auto_error=False)
+_settings = Settings()
+
+
+async def verify_token(
+ credentials: HTTPAuthorizationCredentials | None = Depends(_security),
+) -> None:
+ """Verify Bearer token. Skip auth if API_AUTH_TOKEN is not configured."""
+ token = _settings.api_auth_token.get_secret_value()
+ if not token:
+ return # Auth disabled in dev mode
+
+ if credentials is None or credentials.credentials != token:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid or missing authentication token",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
diff --git a/services/api/src/trading_api/main.py b/services/api/src/trading_api/main.py
index 87306b2..090b110 100644
--- a/services/api/src/trading_api/main.py
+++ b/services/api/src/trading_api/main.py
@@ -1,33 +1,72 @@
"""Trading Platform REST API."""
+import logging
from contextlib import asynccontextmanager
-from fastapi import FastAPI
+from fastapi import Depends, FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from slowapi import Limiter, _rate_limit_exceeded_handler
+from slowapi.errors import RateLimitExceeded
+from slowapi.util import get_remote_address
from shared.config import Settings
from shared.db import Database
-from trading_api.routers import portfolio, orders, strategies
+from trading_api.dependencies.auth import verify_token
+from trading_api.routers import orders, portfolio, strategies
+
+logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
settings = Settings()
+ if not settings.api_auth_token.get_secret_value():
+ logger.warning("API_AUTH_TOKEN not set — authentication is disabled")
app.state.db = Database(settings.database_url.get_secret_value())
await app.state.db.connect()
yield
await app.state.db.close()
+cfg = Settings()
+
+limiter = Limiter(key_func=get_remote_address)
+
app = FastAPI(
title="Trading Platform API",
version="0.1.0",
lifespan=lifespan,
)
-app.include_router(portfolio.router, prefix="/api/v1/portfolio", tags=["portfolio"])
-app.include_router(orders.router, prefix="/api/v1/orders", tags=["orders"])
-app.include_router(strategies.router, prefix="/api/v1/strategies", tags=["strategies"])
+app.state.limiter = limiter
+app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=cfg.cors_origins.split(","),
+ allow_methods=["GET", "POST"],
+ allow_headers=["Authorization", "Content-Type"],
+)
+
+app.include_router(
+ portfolio.router,
+ prefix="/api/v1/portfolio",
+ tags=["portfolio"],
+ dependencies=[Depends(verify_token)],
+)
+app.include_router(
+ orders.router,
+ prefix="/api/v1/orders",
+ tags=["orders"],
+ dependencies=[Depends(verify_token)],
+)
+app.include_router(
+ strategies.router,
+ prefix="/api/v1/strategies",
+ tags=["strategies"],
+ dependencies=[Depends(verify_token)],
+)
@app.get("/health")
diff --git a/services/api/src/trading_api/routers/orders.py b/services/api/src/trading_api/routers/orders.py
index a29ae2f..217efef 100644
--- a/services/api/src/trading_api/routers/orders.py
+++ b/services/api/src/trading_api/routers/orders.py
@@ -2,18 +2,22 @@
import logging
-from fastapi import APIRouter, HTTPException, Request
+from fastapi import APIRouter, HTTPException, Query, Request
from shared.sa_models import OrderRow, SignalRow
+from slowapi import Limiter
+from slowapi.util import get_remote_address
from sqlalchemy import select
from sqlalchemy.exc import OperationalError
logger = logging.getLogger(__name__)
router = APIRouter()
+limiter = Limiter(key_func=get_remote_address)
@router.get("/")
-async def get_orders(request: Request, limit: int = 50):
+@limiter.limit("60/minute")
+async def get_orders(request: Request, limit: int = Query(50, ge=1, le=1000)):
"""Get recent orders."""
try:
db = request.app.state.db
@@ -45,7 +49,8 @@ async def get_orders(request: Request, limit: int = 50):
@router.get("/signals")
-async def get_signals(request: Request, limit: int = 50):
+@limiter.limit("60/minute")
+async def get_signals(request: Request, limit: int = Query(50, ge=1, le=1000)):
"""Get recent signals."""
try:
db = request.app.state.db
diff --git a/services/api/src/trading_api/routers/portfolio.py b/services/api/src/trading_api/routers/portfolio.py
index 3907a86..fde90cb 100644
--- a/services/api/src/trading_api/routers/portfolio.py
+++ b/services/api/src/trading_api/routers/portfolio.py
@@ -2,7 +2,7 @@
import logging
-from fastapi import APIRouter, HTTPException, Request
+from fastapi import APIRouter, HTTPException, Query, Request
from shared.sa_models import PositionRow
from sqlalchemy import select
from sqlalchemy.exc import OperationalError
@@ -39,7 +39,7 @@ async def get_positions(request: Request):
@router.get("/snapshots")
-async def get_snapshots(request: Request, days: int = 30):
+async def get_snapshots(request: Request, days: int = Query(30, ge=1, le=365)):
"""Get portfolio snapshots for the last N days."""
try:
db = request.app.state.db