summaryrefslogtreecommitdiff
path: root/services/api/src/trading_api/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'services/api/src/trading_api/main.py')
-rw-r--r--services/api/src/trading_api/main.py50
1 files changed, 44 insertions, 6 deletions
diff --git a/services/api/src/trading_api/main.py b/services/api/src/trading_api/main.py
index 39f7b43..05c6d2f 100644
--- a/services/api/src/trading_api/main.py
+++ b/services/api/src/trading_api/main.py
@@ -1,33 +1,71 @@
"""Trading Platform REST API."""
+import logging
from contextlib import asynccontextmanager
-from fastapi import FastAPI
+from fastapi import Depends, FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from slowapi import Limiter, _rate_limit_exceeded_handler
+from slowapi.errors import RateLimitExceeded
+from slowapi.util import get_remote_address
from shared.config import Settings
from shared.db import Database
+from trading_api.dependencies.auth import verify_token
+from trading_api.routers import orders, portfolio, strategies
-from trading_api.routers import portfolio, orders, strategies
+logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
settings = Settings()
- app.state.db = Database(settings.database_url)
+ if not settings.api_auth_token.get_secret_value():
+ logger.warning("API_AUTH_TOKEN not set — authentication is disabled")
+ app.state.db = Database(settings.database_url.get_secret_value())
await app.state.db.connect()
yield
await app.state.db.close()
+cfg = Settings()
+
+limiter = Limiter(key_func=get_remote_address)
+
app = FastAPI(
title="Trading Platform API",
version="0.1.0",
lifespan=lifespan,
)
-app.include_router(portfolio.router, prefix="/api/v1/portfolio", tags=["portfolio"])
-app.include_router(orders.router, prefix="/api/v1/orders", tags=["orders"])
-app.include_router(strategies.router, prefix="/api/v1/strategies", tags=["strategies"])
+app.state.limiter = limiter
+app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=cfg.cors_origins.split(","),
+ allow_methods=["GET", "POST"],
+ allow_headers=["Authorization", "Content-Type"],
+)
+
+app.include_router(
+ portfolio.router,
+ prefix="/api/v1/portfolio",
+ tags=["portfolio"],
+ dependencies=[Depends(verify_token)],
+)
+app.include_router(
+ orders.router,
+ prefix="/api/v1/orders",
+ tags=["orders"],
+ dependencies=[Depends(verify_token)],
+)
+app.include_router(
+ strategies.router,
+ prefix="/api/v1/strategies",
+ tags=["strategies"],
+ dependencies=[Depends(verify_token)],
+)
@app.get("/health")